• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 
47 namespace xla {
48 
49 using absl::StrCat;
50 
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52     HloInstruction* root_instruction) {
53   int parameter_count = 0;
54   for (auto& instruction : instructions_) {
55     if (instruction->opcode() == HloOpcode::kParameter) {
56       parameter_count++;
57     }
58   }
59   // If root_instruction is not specified use the last added instruction.
60   HloInstruction* root =
61       root_instruction ? root_instruction : last_added_instruction_;
62   CHECK_NE(nullptr, root);
63   return absl::WrapUnique(new HloComputation(
64       name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
66 
// Private constructor, reached via Builder::Build() or CreateFromProto().
// Takes ownership of every instruction in `instructions` and builds the
// parameter table by parameter number. `fusion_instruction` is non-null iff
// this computation is the fused body of a fusion instruction.
HloComputation::HloComputation(
    const string& name, int parameter_count,
    std::vector<std::unique_ptr<HloInstruction>>* instructions,
    HloInstruction* root_instruction, HloInstruction* fusion_instruction)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(-1),  // Assigned later, once the computation joins a module.
      root_instruction_(root_instruction),
      fusion_instruction_(fusion_instruction),
      is_fusion_computation_(fusion_instruction != nullptr),
      custom_call_instruction_(nullptr),
      is_custom_call_computation_(false) {
  // Pre-size the table so parameters can be slotted in by parameter number,
  // regardless of the order they appear in `instructions`.
  param_instructions_.resize(parameter_count, nullptr);
  bool root_found = false;
  for (auto& instruction : *instructions) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      int64_t param_no = instruction->parameter_number();
      CHECK(param_no >= 0 && param_no < parameter_count)
          << "\nERROR: invalid parameter number.  Expected [0, "
          << parameter_count << "), got " << param_no;
      CHECK(param_instructions_[param_no] == nullptr)
          << "\nERROR: parameter number " << param_no
          << " already allocated in this computation";
      param_instructions_[param_no] = instruction.get();
    }
    root_found |= instruction.get() == root_instruction_;
    AddInstructionInternal(std::move(instruction));
  }
  // The designated root must be one of the instructions we just adopted.
  CHECK(root_found)
      << "\nERROR: root instruction is not present in computation.";
}
97 
// Destructor: break the link from the enclosing fusion instruction (if any)
// so it does not keep a dangling pointer to this computation.
HloComputation::~HloComputation() {
  if (fusion_instruction_ != nullptr) {
    CHECK(fusion_instruction_->fused_instructions_computation() == this);
    fusion_instruction_->ClearCalledComputations();
    fusion_instruction_ = nullptr;
  }
}
105 
// Appends `instruction` to this computation, taking ownership. If `new_name`
// is non-empty the instruction is renamed (after sanitization) first.
// Returns a non-owning pointer to the added instruction. Parameters cannot
// be added this way; use AddParameter/AddEntryComputationParameter.
HloInstruction* HloComputation::AddInstruction(
    std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
  CHECK(instruction->opcode() != HloOpcode::kParameter)
      << "Parameter instructions cannot be added to a computation after "
      << "it has been built";
  if (!new_name.empty()) {
    instruction->SetAndSanitizeName(new_name);
  }
  return AddInstructionInternal(std::move(instruction));
}
116 
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)117 HloInstruction* HloComputation::AddInstructionInternal(
118     std::unique_ptr<HloInstruction> instruction) {
119   if (parent() != nullptr) {
120     instruction->UniquifyName(&parent()->instruction_name_uniquer());
121     instruction->SetUniqueId(parent()->NewUniqueInstructionId());
122   }
123   instruction->set_parent(this);
124   HloInstruction* pinst = instruction.get();
125   instruction_iterators_[pinst] =
126       instructions_.insert(instructions_.end(), std::move(instruction));
127   return pinst;
128 }
129 
// Appends a parameter instruction to a fusion computation. Only valid on
// fusion computations, and only while the fusion instruction's operand count
// equals the current parameter count (the caller is expected to add the
// matching fusion operand in lockstep).
HloInstruction* HloComputation::AddParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK(instruction->opcode() == HloOpcode::kParameter);
  CHECK(IsFusionComputation());
  CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
  instruction->set_parent(this);
  param_instructions_.push_back(instruction.get());
  AddInstructionInternal(std::move(instruction));
  return instructions_.back().get();
}
140 
// Appends a new parameter to the module's entry computation and extends the
// entry computation layout in the module config to match. The parameter's
// number must be the next unused one (== num_parameters()).
HloInstruction* HloComputation::AddEntryComputationParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK_EQ(instruction->parameter_number(), num_parameters());
  CHECK(parent()->entry_computation() == this);

  // Record the new parameter's layout in the module configuration.
  HloModuleConfig config = parent()->config();
  config.mutable_entry_computation_layout()->add_parameter_layout(
      ShapeLayout(instruction->shape()));
  parent()->set_config(config);

  instruction->set_parent(this);
  param_instructions_.push_back(instruction.get());
  AddInstructionInternal(std::move(instruction));

  return instructions_.back().get();
}
158 
// Replaces entry-computation parameter `param_no` with `instruction`:
// updates the module's entry computation layout to the new shape, installs
// the replacement, and force-removes `old_instruction` (bypassing the usual
// removability check). Returns the status of the removal.
Status HloComputation::ReplaceEntryComputationParameter(
    int64_t param_no, HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK(parent()->entry_computation() == this);

  // Overwrite the stored layout for this parameter with the new shape.
  HloModuleConfig config = parent()->config();
  *config.mutable_entry_computation_layout()->mutable_parameter_layout(
      param_no) = ShapeLayout(instruction->shape());
  parent()->set_config(config);

  instruction->set_parent(this);
  param_instructions_[param_no] = instruction.get();
  AddInstructionInternal(std::move(instruction));

  return ForceRemoveInstruction(old_instruction);
}
178 
// Removes parameter `param_no` from a fusion computation. Every
// higher-numbered parameter is recreated with its number (and "param_N"
// name) shifted down by one so parameter numbers stay dense, with all uses
// redirected to the replacement.
Status HloComputation::RemoveParameter(int64_t param_no) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK(IsFusionComputation());
  HloInstruction* param_instruction = param_instructions_[param_no];
  auto param_instruction_iterator = param_instructions_.begin() + param_no;
  param_instructions_.erase(param_instruction_iterator);
  // Throw removed fused parameter instruction away.
  TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));

  // Renumber the remaining parameters. HloInstruction parameter numbers are
  // immutable, so each shifted parameter is recreated rather than mutated.
  while (param_no < param_instructions_.size()) {
    param_instruction = param_instructions_[param_no];
    HloInstruction* new_instr =
        AddInstructionInternal(HloInstruction::CreateParameter(
            param_no, param_instruction->shape(), StrCat("param_", param_no)));
    TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
    param_instructions_[param_no] = new_instr;
    TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
    param_no++;
  }

  return Status::OK();
}
202 
// Removes unused parameters from a fusion computation only; CHECK-fails on
// non-fusion computations (see RemoveUnusedParametersImpl).
Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
  return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
}
206 
// Removes unused parameters from any computation, fusion or not. Callers are
// responsible for keeping call sites consistent with the new signature.
Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
  return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
}
210 
// Shared implementation for the RemoveUnusedParameters* entry points.
// Deletes parameters with no users (and which are not the root), then
// recreates each surviving parameter with a dense, renumbered parameter
// number, redirecting all uses to the replacement.
Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
  CHECK(allow_non_fusion || IsFusionComputation());
  int64_t removed = 0;
  for (int64_t i = 0; i < param_instructions_.size(); ++i) {
    HloInstruction* param_instruction = param_instructions_[i];
    if (param_instruction->user_count() == 0 &&
        param_instruction != root_instruction()) {
      // Unused and not the root: delete it outright.
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
      ++removed;
      continue;
    }

    if (removed > 0) {
      // Earlier parameters were deleted, so this one's number shifts down by
      // `removed`. Recreate it at the new number and move all uses over.
      const int64_t param_no = i - removed;
      HloInstruction* new_instr = AddInstructionInternal(
          HloInstruction::CreateParameter(param_no, param_instruction->shape(),
                                          StrCat("param_", param_no)));
      TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
      param_instructions_[param_no] = new_instr;
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
    }
  }
  // Trim the slots vacated by the shifted-down parameters.
  param_instructions_.resize(param_instructions_.size() - removed);
  return Status::OK();
}
238 
IsSafelyRemovable(const HloInstruction * instruction)239 bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
240   // If the instruction has control predecessors or successors then we cannot
241   // remove the instruction without violating ordering constraints (added, for
242   // example, to avert interference due to buffer aliasing).
243   if (!instruction->control_predecessors().empty() ||
244       !instruction->control_successors().empty()) {
245     return false;
246   }
247 
248   if (instruction->opcode() == HloOpcode::kParameter &&
249       !IsFusionComputation()) {
250     return false;
251   }
252 
253   return true;
254 }
255 
HasSideEffect() const256 bool HloComputation::HasSideEffect() const {
257   for (auto* instruction : instructions()) {
258     if (instruction->HasSideEffect()) {
259       return true;
260     }
261   }
262   return false;
263 }
264 
// Static convenience wrapper over HloInstruction::IsMarkedAsDead (true once
// the instruction has been detached via RemoveInstructionImpl).
bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {
  return inst->IsMarkedAsDead();
}
268 
// Removes `instruction` and, transitively, any operand that becomes unused
// as a result. `cleanup`, if provided, runs on each instruction right before
// it is removed. Side-effecting operands (other than `instruction` itself)
// are never removed, nor is the root or anything still in use.
Status HloComputation::RemoveInstructionAndUnusedOperands(
    HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
  TF_RET_CHECK(root_instruction() != instruction);

  TF_RET_CHECK(instruction->user_count() == 0);
  TF_RET_CHECK(IsSafelyRemovable(instruction))
      << "Cannot remove instruction: " << instruction->ToString();
  absl::flat_hash_set<HloInstruction*> removed;
  std::queue<HloInstruction*> worklist;
  worklist.push(instruction);
  while (!worklist.empty()) {
    HloInstruction* item = worklist.front();
    worklist.pop();

    // Skip anything already removed, still used, the root, unsafe to remove,
    // or side-effecting — except `instruction` itself, whose removal was
    // explicitly requested.
    if (removed.contains(item) || item->user_count() != 0 ||
        item == root_instruction() || !IsSafelyRemovable(item) ||
        (item->HasSideEffect() && item != instruction)) {
      continue;
    }
    // Enqueue operands before removal: they may become unused once `item`
    // is gone and are re-checked when dequeued.
    for (int i = 0; i < item->operand_count(); ++i) {
      worklist.push(item->mutable_operand(i));
    }

    if (cleanup) {
      cleanup(item);
    }
    TF_RETURN_IF_ERROR(RemoveInstruction(item));
    removed.insert(item);
  }
  return Status::OK();
}
300 
// Removes `instruction` with the full safety checks enabled (see
// RemoveInstructionImpl for the invariants enforced).
Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
  return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
}
304 
// Removes `instruction`, skipping the IsSafelyRemovable() check; the other
// invariants (not root, no users, no control deps) are still enforced.
Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
  return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
}
308 
// Detaches `instruction` from this computation and parks it on
// to_be_deleted_ (actual destruction is deferred). The instruction must not
// be the root, must have no users, and must have no control dependencies;
// `ignore_safety_check` only skips the IsSafelyRemovable() test.
Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
                                             bool ignore_safety_check) {
  VLOG(2) << "Removing instruction " << instruction->name()
          << " from computation " << name();
  TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
      << "cannot remove instruction: " << instruction->ToString();
  TF_RET_CHECK(root_instruction() != instruction)
      << "cannot remove root instruction " << instruction->name();
  TF_RET_CHECK(instruction->user_count() == 0)
      << "instruction " << instruction->name()
      << " has users and cannot be removed";
  TF_RET_CHECK(instruction->control_predecessors().empty())
      << "instruction " << instruction->name()
      << " has control predecessors and cannot be removed";
  TF_RET_CHECK(instruction->control_successors().empty())
      << "instruction " << instruction->name()
      << " has control successors and cannot be removed";

  auto inst_it = instruction_iterators_.find(instruction);
  TF_RET_CHECK(inst_it != instruction_iterators_.end());
  (*inst_it->second)->set_parent(nullptr);
  // Transfer ownership out of instructions_ into the deferred-deletion list,
  // then sever every graph link and mark the instruction dead.
  to_be_deleted_.emplace_back(inst_it->second->release());
  to_be_deleted_.back()->DetachFromOperandsAndUsers();
  // Clear all operands to avoid Null operands.
  to_be_deleted_.back()->RemoveAllOperands();
  to_be_deleted_.back()->ClearCalledComputations();
  to_be_deleted_.back()->MarkAsDead();
  instructions_.erase(inst_it->second);
  instruction_iterators_.erase(inst_it);
  return Status::OK();
}
340 
set_root_instruction(HloInstruction * new_root_instruction,bool accept_different_shape)341 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
342                                           bool accept_different_shape) {
343   // The shape of the root (ignoring layout) is an invariant of the computation
344   // for non-fusion cases.
345   if (!IsFusionComputation() && !accept_different_shape) {
346     CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
347                                 root_instruction_->shape()))
348         << new_root_instruction->shape() << " is incompatible with "
349         << root_instruction_->shape();
350   }
351   bool root_found = false;
352   for (auto& instruction : instructions_) {
353     if (new_root_instruction == instruction.get()) {
354       root_found = true;
355       break;
356     }
357   }
358   DCHECK(root_found);
359 
360   if (parent() && parent()->has_entry_computation() &&
361       parent()->entry_computation() == this) {
362     if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
363                                        root_instruction_->shape())) {
364       // Rebuild input output alias config now that we have a new output shape.
365       parent()->input_output_alias_config() =
366           HloInputOutputAliasConfig(new_root_instruction->shape());
367     }
368   }
369 
370   root_instruction_ = new_root_instruction;
371 }
372 
namespace {

// Helper which builds a post order of the HLO call graph: every computation
// called (transitively) by `computation` is appended to `post_order` before
// `computation` itself. `visited` guards against revisiting a computation
// reachable through multiple call sites.
void ComputeComputationPostOrder(HloComputation* computation,
                                 absl::flat_hash_set<HloComputation*>* visited,
                                 std::vector<HloComputation*>* post_order) {
  if (visited->insert(computation).second) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        ComputeComputationPostOrder(called_computation, visited, post_order);
      }
    }
    post_order->push_back(computation);
  }
}

}  // namespace
391 
// Appends to `post_order` a post order of the instructions reachable from
// `root`, using an explicit (non-recursive) DFS. Instructions sharing a
// channel id (per `channel_dependency_group`) are treated as a unit:
// visiting any member pulls in the predecessors of all members, keeping
// channel peers correctly ordered relative to one another. `visited` holds
// the DFS coloring (kVisiting = on stack, kVisited = emitted) and is shared
// across calls so repeated roots don't re-emit instructions.
void HloComputation::ComputeInstructionPostOrder(
    const HloComputation::ChannelDependencyGroup& channel_dependency_group,
    std::vector<HloInstruction*>* post_order, HloInstruction* root,
    absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
  std::vector<HloInstruction*> dfs_stack;
  dfs_stack.push_back(root);
  while (!dfs_stack.empty()) {
    const auto current = dfs_stack.back();
    CHECK_EQ(current->parent(), this)
        << "Instruction " << current->name()
        << " is not in the current computation (" << name() << ").";
    auto it = visited->find(current);
    if (it != visited->end()) {
      if (it->second == kVisited) {
        // Already visited.
        dfs_stack.pop_back();
        continue;
      }
      // Visit this node: all predecessors have been emitted since the node
      // was first pushed, so it can now join the post order.
      CHECK_EQ(kVisiting, it->second);
      dfs_stack.pop_back();
      post_order->push_back(current);
      it->second = kVisited;
      continue;
    }

    visited->insert({current, kVisiting});

    // Returns the channel id for ops whose channel peers must be kept
    // together, or nullopt for every other opcode.
    const auto get_channel_id =
        [](HloInstruction* inst) -> absl::optional<int64> {
      switch (inst->opcode()) {
        case HloOpcode::kRecvDone:
        case HloOpcode::kAllReduce:
        case HloOpcode::kAllGather:
        case HloOpcode::kAllToAll:
        case HloOpcode::kReduceScatter:
          return inst->channel_id();
        default:
          return absl::nullopt;
      }
    };

    // When adding a predecessor to the dfs_stack, we need to also add its
    // associated channel dependencies.
    const auto add_dfs_stack = [&](HloInstruction* inst) {
      auto channel_id = get_channel_id(inst);
      if (channel_id && channel_dependency_group.count(*channel_id)) {
        auto it = channel_dependency_group.find(*channel_id);
        for (HloInstruction* cinst : it->second) {
          dfs_stack.emplace_back(cinst);
        }
      } else {
        dfs_stack.emplace_back(inst);
      }
    };

    const auto add_predecessors = [&](HloInstruction* inst) {
      // Add the operands to the stack in reverse order so the first operand is
      // processed first. This will produce a more natural ordering and a nicer
      // result for things like HLO stringification.
      const auto& operands = inst->operands();
      for (int64_t i = operands.size() - 1; i >= 0; --i) {
        add_dfs_stack(operands[i]);
      }

      for (HloInstruction* op : inst->control_predecessors()) {
        add_dfs_stack(op);
      }
    };

    // If the current instruction is a channel instruction, add the dependencies
    // from all associated instructions of the channel.
    auto channel_id = get_channel_id(current);
    if (channel_id && channel_dependency_group.count(*channel_id)) {
      auto it = channel_dependency_group.find(*channel_id);
      for (HloInstruction* cinst : it->second) {
        add_predecessors(cinst);
      }
    } else {
      add_predecessors(current);
    }
  }
}
475 
// Groups this computation's channel instructions (kSend/kRecvDone and the
// channel-carrying collectives) by channel id, for use by the post-order
// DFS. Returns an empty map when the module config shows no cross-module
// ordering is needed (single-computation static device assignment, or SPMD
// partitioning).
HloComputation::ChannelDependencyGroup
HloComputation::ComputeChannelDependencies() const {
  ChannelDependencyGroup channel_dependency_group;
  if (parent() && parent()->config().has_static_device_assignment() &&
      (parent()->config().static_device_assignment().computation_count() == 1 ||
       parent()->config().use_spmd_partitioning())) {
    return channel_dependency_group;
  }
  for (const auto& instruction : instructions_) {
    switch (instruction->opcode()) {
      case HloOpcode::kSend:
      case HloOpcode::kRecvDone:
      case HloOpcode::kAllReduce:
      case HloOpcode::kAllGather:
      case HloOpcode::kAllToAll:
      case HloOpcode::kReduceScatter: {
        // Collectives without a channel id impose no cross-instruction
        // ordering and are skipped.
        auto channel_id = instruction->channel_id();
        if (channel_id) {
          channel_dependency_group[channel_id.value()].push_back(
              instruction.get());
        }
        break;
      }
      default:
        break;
    }
  }
  return channel_dependency_group;
}
505 
HasOnlyTraceUsers(const HloInstruction * instruction)506 static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
507   return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
508     return user->opcode() == HloOpcode::kTrace;
509   });
510 }
511 
// Returns every instruction of this computation in a valid post order
// (operands before users), honoring channel dependencies. The DFS starts
// only from "sink" instructions (those with no non-trace users); trace
// instructions are excluded from the DFS and appended at the end.
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
  auto channel_dependency_group = ComputeChannelDependencies();
  std::vector<HloInstruction*> post_order;
  post_order.reserve(instruction_count());
  std::vector<HloInstruction*> trace_instructions;
  absl::flat_hash_map<HloInstruction*, VisitState> visited;
  visited.reserve(instruction_count());
  for (auto& instruction : instructions_) {
    if (instruction->opcode() == HloOpcode::kTrace) {
      // Trace instructions aren't handled by the DFS visitor. Add trace
      // instructions to the post order at the end (necessarily they have no
      // users).
      trace_instructions.push_back(instruction.get());
    } else if (HasOnlyTraceUsers(instruction.get())) {
      ComputeInstructionPostOrder(channel_dependency_group, &post_order,
                                  instruction.get(), &visited);
    }
  }
  post_order.insert(post_order.end(), trace_instructions.begin(),
                    trace_instructions.end());
  // Every instruction must appear exactly once in the result.
  CHECK_EQ(instructions_.size(), post_order.size())
      << "number of instructions does not match post order size";
  return post_order;
}
536 
// Returns a post order of every computation transitively called by this one
// (callees before callers), excluding this computation itself.
std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
    const {
  absl::flat_hash_set<HloComputation*> visited;
  std::vector<HloComputation*> post_order;

  // To avoid special handling of this computation, cast away const of
  // 'this'. 'this' is immediately removed from the post order after
  // construction.
  //
  // TODO(b/78350259): This violates const-correctness, since while the original
  // computation is not returned, we still retrieve non-const computations from
  // a const one. Consider also avoiding const for HloComputation, or review XLA
  // for const-correctness of non-HloInstruction* types like this.
  ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
                              &post_order);

  // We don't want to include this computation in the post order.
  CHECK_EQ(this, post_order.back());
  post_order.pop_back();

  return post_order;
}
559 
// Prints this computation using the default instruction order (post order).
string HloComputation::ToString(const HloPrintOptions& options) const {
  return ToString(options, MakeInstructionPostOrder());
}
563 
ToString(const HloPrintOptions & options,absl::Span<const HloInstruction * const> instruction_order) const564 string HloComputation::ToString(
565     const HloPrintOptions& options,
566     absl::Span<const HloInstruction* const> instruction_order) const {
567   CHECK_EQ(instruction_order.size(), instruction_count());
568 
569   const string tab(2 * options.indent_amount(), ' ');
570 
571   std::ostringstream s;
572   s << tab;
573 
574   if (!options.is_in_nested_computation()) {
575     if (options.print_percent()) {
576       s << "%";
577     }
578     if (options.print_ids()) {
579       // Exclude entry computation's name because it includes and leads to
580       // non-deterministic fingerprint.
581       s << PrintName(name(), options.print_ids()) << " ";
582     }
583   }
584 
585   if (options.print_program_shape()) {
586     s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
587       << " ";
588   }
589   s << "{\n";
590 
591   // There are instructions which are required to be printed. Additionally, we
592   // print some instructions before and after required ones. The resulting
593   // output has the following format.
594   //
595   //  computation {
596   //    ...
597   //    additional_instructions
598   //    required_instructions
599   //    additional_instructions
600   //    ...
601   //    additional_instructions
602   //    required_instructions
603   //    additional_instructions
604   //    ...
605   //  }
606   std::set<int> instructions_to_print;
607   {
608     // Find all the instructions that should be printed.
609     auto add_instruction = [&instructions_to_print,
610                             &instruction_order](int index) {
611       if (index < 0 || index >= instruction_order.size()) {
612         return;
613       }
614       instructions_to_print.insert(index);
615     };
616 
617     auto add_instructions_arround = [&add_instruction, &options](int index) {
618       for (int i = index - options.leading_and_trailing_instructions_number();
619            i <= index + options.leading_and_trailing_instructions_number();
620            ++i) {
621         add_instruction(i);
622       }
623     };
624 
625     for (int i = 0; i < instruction_order.size(); ++i) {
626       const HloInstruction* instruction = instruction_order[i];
627       CHECK_EQ(this, instruction->parent());
628       if (options.print_instruction(instruction)) {
629         add_instructions_arround(i);
630       }
631     }
632   }
633 
634   {
635     // Print the instructions in this computation.
636     HloPrintOptions new_options = options;
637     new_options.set_indent_amount(options.indent_amount() + 1)
638         .set_is_in_nested_computation(true);
639 
640     const string new_tab(2 * new_options.indent_amount(), ' ');
641 
642     CanonicalNameMap name_map;
643 
644     bool print_prev = true;
645     for (int index = 0; index < instruction_order.size(); ++index) {
646       const HloInstruction* instruction = instruction_order[index];
647       if (instructions_to_print.find(index) != instructions_to_print.end()) {
648         s << new_options.format_instruction(
649                  instruction,
650                  instruction->ToStringWithCanonicalNameMap(new_options,
651                                                            &name_map),
652                  new_options.indent_amount(), instruction == root_instruction_)
653           << "\n";
654         print_prev = true;
655       } else if (print_prev) {
656         s << new_tab << "...\n";
657         print_prev = false;
658       }
659     }
660   }
661 
662   s << tab << "}";
663   return s.str();
664 }
665 
// Serializes this computation: instructions in post order, root id, and
// program shape. The computation must already belong to a module so that
// unique_id_ has been assigned.
HloComputationProto HloComputation::ToProto() const {
  HloComputationProto proto;
  CHECK(unique_id_ != -1)
      << "This computation does not have a valid id. Please make sure the "
         "computation is inside a module before dumping it.";
  proto.set_id(unique_id_);
  proto.set_name(name_);
  for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
    HloInstructionProto instruction_proto = instruction->ToProto();
    proto.add_instructions()->Swap(&instruction_proto);
  }
  proto.set_root_id(root_instruction()->unique_id());
  *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
  return proto;
}
681 
// Deserializes a computation from `proto`. `computation_map` maps proto ids
// to already-deserialized computations that instructions here may call.
// Fails if the proto is malformed: unknown/missing root id, duplicate
// instruction ids, or parameter numbers not forming [0, parameter_count).
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
    const HloComputationProto& proto,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map,
    bool prohibit_empty_literal) {
  absl::flat_hash_map<int64, HloInstruction*> instruction_map;
  absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  int64_t parameter_count = 0;
  for (const HloInstructionProto& instruction_proto : proto.instructions()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
                        HloInstruction::CreateFromProto(
                            instruction_proto, instruction_map, computation_map,
                            prohibit_empty_literal));
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }
    TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
    instruction_map[instruction_proto.id()] = instruction.get();
    to_proto_id[instruction.get()] = instruction_proto.id();
    instructions.push_back(std::move(instruction));
  }

  TF_RET_CHECK(proto.root_id() != -1);
  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
  HloInstruction* root = instruction_map.at(proto.root_id());

  // Sort the instructions in the proto id's order.
  absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
                                 const std::unique_ptr<HloInstruction>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Validate that the parameter numbers cover exactly [0, parameter_count)
  // with no duplicates before handing off to the constructor.
  TF_RETURN_IF_ERROR([&]() -> Status {
    std::vector<bool> parameters_seen(parameter_count);
    int parameters_seen_count = 0;
    for (auto& instruction : instructions) {
      if (instruction->opcode() == HloOpcode::kParameter) {
        int64_t param_no = instruction->parameter_number();
        TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
            << "Invalid parameter number.  Expected [0, " << parameter_count
            << "), got " << param_no;
        TF_RET_CHECK(!parameters_seen[param_no])
            << "Parameter number " << param_no
            << " already allocated in this computation";
        parameters_seen[param_no] = true;
        parameters_seen_count++;
      }
    }
    TF_RET_CHECK(parameters_seen_count == parameter_count)
        << "Not all parameters in range [0, " << parameter_count
        << ") were referenced";
    return Status::OK();
  }());

  auto computation = absl::WrapUnique(
      new HloComputation(proto.name(), parameter_count, &instructions, root,
                         /*fusion_instruction=*/nullptr));
  computation->unique_id_ = proto.id();
  return std::move(computation);
}
743 
FuseInstructionsInto(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction * fusion_instruction)744 void HloComputation::FuseInstructionsInto(
745     absl::Span<HloInstruction* const> instructions_to_fuse,
746     HloInstruction* fusion_instruction) {
747   CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
748   HloInstruction* root = instructions_to_fuse.front();
749   TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
750   if (root == root_instruction()) {
751     set_root_instruction(fusion_instruction);
752   }
753   TF_CHECK_OK(RemoveInstruction(root));
754   for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
755     HloInstruction* instruction = instructions_to_fuse[i];
756     fusion_instruction->FuseInstruction(instruction);
757     if (instruction->user_count() == 0) {
758       TF_CHECK_OK(RemoveInstruction(instruction));
759     }
760   }
761 }
762 
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)763 HloInstruction* HloComputation::CreateFusionInstruction(
764     absl::Span<HloInstruction* const> instructions_to_fuse,
765     HloInstruction::FusionKind fusion_kind) {
766   HloInstruction* root = instructions_to_fuse.front();
767   HloInstruction* fusion_instruction = AddInstruction(
768       HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
769   FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
770   return fusion_instruction;
771 }
772 
DeepCopyHelper(HloInstruction * instruction,ShapeIndex * index,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)773 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
774     HloInstruction* instruction, ShapeIndex* index,
775     const std::function<
776         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
777                         HloComputation* computation)>& copy_leaf) {
778   if (instruction->shape().IsTuple()) {
779     std::vector<HloInstruction*> elements;
780     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
781          i++) {
782       HloInstruction* gte =
783           AddInstruction(HloInstruction::CreateGetTupleElement(
784               ShapeUtil::GetTupleElementShape(instruction->shape(), i),
785               instruction, i));
786 
787       index->push_back(i);
788       TF_ASSIGN_OR_RETURN(HloInstruction * element,
789                           DeepCopyHelper(gte, index, copy_leaf));
790       elements.push_back(element);
791       index->pop_back();
792     }
793     return AddInstruction(HloInstruction::CreateTuple(elements));
794   }
795   if (instruction->shape().IsToken()) {
796     // Tokens have no on-device representation and cannot be copied. Pass
797     // through transparently.
798     return instruction;
799   }
800 
801   // Array shape.
802   TF_RET_CHECK(instruction->shape().IsArray());
803   return copy_leaf(instruction, *index, this);
804 }
805 
DeepCopyInstruction(HloInstruction * instruction,const ShapeTree<bool> * indices_to_copy,ShapeTree<HloInstruction * > * copies_added)806 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
807     HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
808     ShapeTree<HloInstruction*>* copies_added) {
809   if (instruction->parent() != this) {
810     return FailedPrecondition(
811         "Can't deep copy instruction %s: instruction is not in computation %s",
812         instruction->name(), name());
813   }
814   if (indices_to_copy != nullptr &&
815       !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
816     return FailedPrecondition(
817         "Can't deep copy instruction %s: given shape tree of indices to copy "
818         "has incompatible shapes: %s vs. %s",
819         instruction->name(), ShapeUtil::HumanString(instruction->shape()),
820         ShapeUtil::HumanString(indices_to_copy->shape()));
821   }
822 
823   ShapeIndex index;
824   auto copy_leaf = [indices_to_copy, copies_added](
825                        HloInstruction* leaf, const ShapeIndex& leaf_index,
826                        HloComputation* computation) {
827     if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
828       HloInstruction* copy = computation->AddInstruction(
829           HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
830       if (copies_added != nullptr) {
831         *copies_added->mutable_element(leaf_index) = copy;
832       }
833       return copy;
834     }
835     // Elements which are not to be copied are passed through
836     // transparently.
837     return leaf;
838   };
839   return DeepCopyHelper(instruction, &index, copy_leaf);
840 }
841 
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)842 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
843     HloInstruction* instruction,
844     const std::function<
845         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
846                         HloComputation* computation)>& copy_leaf) {
847   if (instruction->parent() != this) {
848     return FailedPrecondition(
849         "Can't deep copy instruction %s: instruction is not in computation %s",
850         instruction->name(), name());
851   }
852   ShapeIndex index;
853   return DeepCopyHelper(instruction, &index, copy_leaf);
854 }
855 
ComputeProgramShape(bool include_ids) const856 ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
857   ProgramShape program_shape;
858 
859   for (auto* param_instruction : param_instructions_) {
860     *program_shape.add_parameters() = param_instruction->shape();
861     *program_shape.add_parameter_names() =
862         PrintName(param_instruction->name(), include_ids);
863   }
864   *program_shape.mutable_result() = root_instruction_->shape();
865 
866   return program_shape;
867 }
868 
EqualInternal(const HloComputation & other,bool is_layout_sensitive,bool ignore_channel_id_values) const869 bool HloComputation::EqualInternal(const HloComputation& other,
870                                    bool is_layout_sensitive,
871                                    bool ignore_channel_id_values) const {
872   if (this == &other) {
873     return true;
874   }
875   absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
876       visited;
877   std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
878 
879   worklist.push_back({root_instruction(), other.root_instruction()});
880 
881   while (!worklist.empty()) {
882     auto pair = worklist.back();
883     worklist.pop_back();
884 
885     if (visited.contains(pair)) {
886       continue;
887     }
888     visited.emplace(pair);
889     // TODO(b/123082518): Avoid recursively invoking Equal because it may
890     // cause a stack overflow with deeply nested subcomputations.
891     auto operands_eq = [](const HloInstruction*, const HloInstruction*) {
892       return true;
893     };
894     auto comp_eq = [&](const HloComputation* a, const HloComputation* b) {
895       return a->EqualInternal(*b, is_layout_sensitive,
896                               ignore_channel_id_values);
897     };
898     bool identical_ignoring_operands =
899         ignore_channel_id_values
900             ? pair.first->IdenticalIgnoringChannelIdValues(
901                   *pair.second, operands_eq, comp_eq, is_layout_sensitive)
902             : pair.first->Identical(*pair.second, operands_eq, comp_eq,
903                                     is_layout_sensitive);
904     if (!identical_ignoring_operands) {
905       return false;
906     }
907     for (size_t i = 0; i < pair.first->operands().size(); ++i) {
908       worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
909     }
910   }
911   return true;
912 }
913 
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)914 Status HloComputation::ReplaceWithNewInstruction(
915     HloInstruction* old_instruction,
916     std::unique_ptr<HloInstruction> new_instruction) {
917   return ReplaceInstruction(old_instruction,
918                             AddInstruction(std::move(new_instruction)));
919 }
920 
ReplaceWithNewEntryComputationParameter(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)921 Status HloComputation::ReplaceWithNewEntryComputationParameter(
922     HloInstruction* old_instruction,
923     std::unique_ptr<HloInstruction> new_instruction) {
924   return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
925                                                  std::move(new_instruction)));
926 }
927 
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)928 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
929                                           HloInstruction* new_instruction) {
930   TF_RET_CHECK(
931       ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
932       << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
933       << ShapeUtil::HumanString(new_instruction->shape());
934 
935   VLOG(10) << "transformed " << old_instruction->ToString() << " to "
936            << new_instruction->ToString();
937   // Try to add metadata for HLO instructions that are created to replace
938   // existing HLO instructions (e.g. during optimizations). The assumption is
939   // that the old instruction and the new instruction would perform the same
940   // function, and that they would be correlated to the same TF op. This might
941   // not always be correct since HLO optimizations can cross TF op boundaries.
942   // But still this seems to be better than nothing.
943   bool overwrite_op_name = new_instruction->metadata().op_name().empty() &&
944                            !old_instruction->metadata().op_name().empty();
945   bool overwrite_pass_id =
946       new_instruction->metadata().op_name().empty() &&
947       new_instruction->metadata().logical_creation_pass_id() == 0 &&
948       old_instruction->metadata().logical_creation_pass_id() != 0;
949   if (overwrite_op_name || overwrite_pass_id) {
950     new_instruction->set_metadata(old_instruction->metadata());
951   }
952   if (new_instruction->frontend_attributes().map().empty()) {
953     new_instruction->set_frontend_attributes(
954         old_instruction->frontend_attributes());
955   }
956 
957   // Like the metadata above, if the user didn't specify any sharding
958   // information on the new instruction we should copy the old sharding
959   // information (if any).
960   if (!new_instruction->has_sharding()) {
961     new_instruction->set_sharding(old_instruction->sharding_ptr());
962   }
963 
964   TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
965   return RemoveInstructionAndUnusedOperands(old_instruction);
966 }
967 
CollectUnreachableRoots() const968 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
969   std::vector<HloInstruction*> unreachable_roots;
970   for (auto* instruction : instructions()) {
971     if (instruction->user_count() == 0 &&
972         instruction->control_successors().empty() &&
973         instruction != root_instruction()) {
974       unreachable_roots.push_back(instruction);
975     }
976   }
977   VLOG(3) << "Unreachable roots:"
978           << absl::StrJoin(unreachable_roots, "\n\t",
979                            [](string* out, const HloInstruction* hlo) {
980                              absl::StrAppend(out, hlo->ToString());
981                            });
982   return unreachable_roots;
983 }
984 
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const985 Status HloComputation::AcceptWithOperandOrder(
986     DfsHloVisitor* visitor,
987     const HloInstruction::CompareFunction& operand_order) const {
988   // Visit unreachable roots. Beware that the visitor might delete the currently
989   // visited root, which would invalidate iterators if the unreachable roots
990   // weren't computed ahead of time.
991   for (HloInstruction* root : CollectUnreachableRoots()) {
992     TF_RETURN_IF_ERROR(
993         root->AcceptWithOperandOrder(visitor, operand_order,
994                                      /*call_finish_visit=*/false));
995   }
996   // Visit the computation root instruction last.
997   return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
998                                                     /*call_finish_visit=*/true);
999 }
1000 
Clone(const string & suffix,HloCloneContext * context)1001 std::unique_ptr<HloComputation> HloComputation::Clone(
1002     const string& suffix, HloCloneContext* context) {
1003   return CloneWithReplacements(
1004       /*replacements=*/absl::flat_hash_map<const HloInstruction*,
1005                                            std::unique_ptr<HloInstruction>>(),
1006       /*extra_parameters=*/{}, context, suffix);
1007 }
1008 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)1009 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1010     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1011     HloCloneContext* context, const string& suffix) {
1012   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1013       replacements;
1014   replacements.emplace(std::move(r1));
1015   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1016                                context, suffix);
1017 }
1018 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)1019 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1020     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1021     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1022     HloCloneContext* context, const string& suffix) {
1023   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1024       replacements;
1025   replacements.emplace(std::move(r1));
1026   replacements.emplace(std::move(r2));
1027   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1028                                context, suffix);
1029 }
1030 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)1031 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1032     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1033     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1034     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
1035     HloCloneContext* context, const string& suffix) {
1036   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1037       replacements;
1038   replacements.emplace(std::move(r1));
1039   replacements.emplace(std::move(r2));
1040   replacements.emplace(std::move(r3));
1041   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1042                                context, suffix);
1043 }
1044 
// Clones this computation, substituting instructions per `replacements` (a
// mapping from original instruction to its replacement; a null mapped value
// drops the instruction and everything reachable only through it), adding
// `extra_parameters` (which must all be kParameter) to the clone, and rooting
// the clone at `new_root` (or at this computation's root when null). The
// clone is registered in `context` (one is created locally when null).
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const string& suffix,
    const HloInstruction* new_root) {
  std::unique_ptr<HloCloneContext> context_ptr;
  if (context == nullptr) {
    // No context supplied: own a fresh one for the duration of this call.
    context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
    context = context_ptr.get();
  }
  if (new_root == nullptr) {
    new_root = root_instruction();
  }

  // Look up instr in the replacements map, and return either the replacement,
  // or instr, if the replacement isn't present.
  //
  // Note: This can return null, indicating that instr should not be present in
  // the new computation.
  auto replace = [&](const HloInstruction* instr) {
    auto it = replacements.find(instr);
    return it != replacements.end() ? it->second.get() : instr;
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We want to do a postorder walk over [replace(i) for i in instructions_].
  // We can't reuse MakeInstructionPostOrder() for this, because that will
  // generate a postorder of plain instructions_, and our replacements may
  // change the postorder!
  //
  // The postorder we want here is simpler than what MakeInstructionPostOrder()
  // does -- we only care about operand dependencies -- so let's just do it
  // ourselves.
  std::vector<const HloInstruction*> postorder;
  absl::flat_hash_map<const HloInstruction*, VisitState> visited;
  for (const auto& instr : instructions_) {
    std::vector<const HloInstruction*> dfs_stack;
    const HloInstruction* new_instr = replace(instr.get());
    if (!new_instr) {
      // Replaced by null: this instruction is dropped from the clone.
      continue;
    }
    dfs_stack.push_back(new_instr);

    // Iterative DFS: an instruction is appended to `postorder` only after all
    // of its (replaced) operands have been emitted.
    while (!dfs_stack.empty()) {
      auto* cur = dfs_stack.back();
      auto it = visited.find(cur);
      if (it != visited.end()) {
        dfs_stack.pop_back();
        if (it->second == kVisited) {
          continue;
        }
        // Second encounter while kVisiting: all operands done, emit it now.
        CHECK_EQ(it->second, kVisiting);
        postorder.push_back(cur);
        it->second = kVisited;
        continue;
      }

      visited.insert({cur, kVisiting});
      for (HloInstruction* operand : cur->operands()) {
        const HloInstruction* new_operand = replace(operand);
        if (new_operand) {
          dfs_stack.emplace_back(new_operand);
        }
      }
    }
  }

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // First add the extra parameters to 'instructions'.
  for (const auto& instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  // Clone each instruction in postorder, rewiring its operands through the
  // clone context (operands are guaranteed to have been cloned already).
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      // A used instruction must not be mapped to null by `replacements`.
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(context->GetInstruction(replaced_operand));
    }
    std::unique_ptr<HloInstruction> new_instr =
        instr->CloneWithNewOperands(instr->shape(), new_operands, context);
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parameter_replicated_at_leaf_buffers().has_value()) {
      // Preserve per-leaf-buffer replication info on cloned parameters.
      new_instr->set_parameter_replicated_at_leaf_buffers(
          instr->parameter_replicated_at_leaf_buffers().value());
    }
    instructions.push_back(std::move(new_instr));
  }
  Builder builder(name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(replace(new_root)));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = context->GetInstruction(instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);
      // successor may not have been remapped, because it might have been
      // removed by the replacements map.
      if (replaced_successor != nullptr) {
        TF_CHECK_OK(new_instr->AddControlDependencyTo(
            context->GetInstruction(replaced_successor)));
      }
    }
  }
  context->MapComputation(this, result.get());
  return result;
}
1162 
// Replaces this computation's name with a unique variant of it produced by
// `name_uniquer` (e.g. by appending a numeric suffix).
void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  name_ = name_uniquer->GetUniqueName(name_);
}
1166 
GetInstructionWithName(absl::string_view name)1167 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1168   auto instructions_in_computation = instructions();
1169   auto it = absl::c_find_if(
1170       instructions_in_computation,
1171       [&](HloInstruction* instr) { return instr->name() == name; });
1172   return it == instructions_in_computation.end() ? nullptr : *it;
1173 }
1174 
// True iff this computation is the entry computation of its parent module.
bool HloComputation::IsEntryComputation() const {
  return parent()->entry_computation() == this;
}
1178 }  // namespace xla
1179