• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 
47 namespace xla {
48 
49 using absl::StrCat;
50 
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52     HloInstruction* root_instruction) {
53   int parameter_count = 0;
54   for (auto& instruction : instructions_) {
55     if (instruction->opcode() == HloOpcode::kParameter) {
56       parameter_count++;
57     }
58   }
59   // If root_instruction is not specified use the last added instruction.
60   HloInstruction* root =
61       root_instruction ? root_instruction : last_added_instruction_;
62   CHECK_NE(nullptr, root);
63   return absl::WrapUnique(new HloComputation(
64       name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
66 
// Private constructor, reached via Builder::Build or CreateFromProto. Takes
// ownership of every instruction in `*instructions`, wires each parameter
// into its numbered slot, and verifies that `root_instruction` is actually
// one of the supplied instructions.
HloComputation::HloComputation(
    const string& name, int parameter_count,
    std::vector<std::unique_ptr<HloInstruction>>* instructions,
    HloInstruction* root_instruction, HloInstruction* fusion_instruction)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(-1),  // No valid id until the computation joins a module.
      root_instruction_(root_instruction),
      fusion_instruction_(fusion_instruction) {
  // Pre-size the parameter table; slots are filled as parameters are seen.
  param_instructions_.resize(parameter_count, nullptr);
  bool root_found = false;
  for (auto& instruction : *instructions) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      int64 param_no = instruction->parameter_number();
      // Each parameter number must be in range and used exactly once.
      CHECK(param_no >= 0 && param_no < parameter_count)
          << "\nERROR: invalid parameter number.  Expected [0, "
          << parameter_count << "), got " << param_no;
      CHECK(param_instructions_[param_no] == nullptr)
          << "\nERROR: parameter number " << param_no
          << " already allocated in this computation";
      param_instructions_[param_no] = instruction.get();
    }
    root_found |= instruction.get() == root_instruction_;
    // Moves ownership into this computation's instruction list.
    AddInstructionInternal(std::move(instruction));
  }
  CHECK(root_found)
      << "\nERROR: root instruction is not present in computation.";
}
94 
// Appends `instruction` to this computation and returns a non-owning pointer
// to it. Parameters are rejected here because the parameter list is fixed at
// build time; fusion computations use AddParameter instead.
HloInstruction* HloComputation::AddInstruction(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK(instruction->opcode() != HloOpcode::kParameter)
      << "Parameter instructions cannot be added to a computation after "
      << "it has been built";
  return AddInstructionInternal(std::move(instruction));
}
102 
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)103 HloInstruction* HloComputation::AddInstructionInternal(
104     std::unique_ptr<HloInstruction> instruction) {
105   if (parent() != nullptr) {
106     instruction->UniquifyName(&parent()->instruction_name_uniquer());
107     instruction->SetUniqueId(parent()->NewUniqueInstructionId());
108   }
109   instruction->set_parent(this);
110   HloInstruction* pinst = instruction.get();
111   instruction_iterators_[pinst] =
112       instructions_.insert(instructions_.end(), std::move(instruction));
113   return pinst;
114 }
115 
AddParameter(std::unique_ptr<HloInstruction> instruction)116 HloInstruction* HloComputation::AddParameter(
117     std::unique_ptr<HloInstruction> instruction) {
118   CHECK(instruction->opcode() == HloOpcode::kParameter);
119   CHECK(IsFusionComputation());
120   CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
121   instruction->set_parent(this);
122   param_instructions_.push_back(instruction.get());
123   AddInstructionInternal(std::move(instruction));
124   return instructions_.back().get();
125 }
126 
AddEntryComputationParameter(std::unique_ptr<HloInstruction> instruction)127 HloInstruction* HloComputation::AddEntryComputationParameter(
128     std::unique_ptr<HloInstruction> instruction) {
129   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
130   CHECK_EQ(instruction->parameter_number(), num_parameters());
131   CHECK(parent()->entry_computation() == this);
132 
133   HloModuleConfig config = parent()->config();
134   config.mutable_entry_computation_layout()->add_parameter_layout(
135       ShapeLayout(instruction->shape()));
136   parent()->set_config(config);
137 
138   instruction->set_parent(this);
139   param_instructions_.push_back(instruction.get());
140   AddInstructionInternal(std::move(instruction));
141 
142   return instructions_.back().get();
143 }
144 
// Replaces entry-computation parameter `param_no` with `instruction`,
// updating the entry computation layout to the new shape and force-removing
// the old parameter (ForceRemoveInstruction skips the parameter safety
// check that would otherwise reject removing a non-fusion parameter).
Status HloComputation::ReplaceEntryComputationParameter(
    int64 param_no, HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK(parent()->entry_computation() == this);

  // Overwrite the layout slot for this parameter in the module config.
  HloModuleConfig config = parent()->config();
  *config.mutable_entry_computation_layout()->mutable_parameter_layout(
      param_no) = ShapeLayout(instruction->shape());
  parent()->set_config(config);

  instruction->set_parent(this);
  param_instructions_[param_no] = instruction.get();
  AddInstructionInternal(std::move(instruction));

  return ForceRemoveInstruction(old_instruction);
}
164 
// Removes fusion-computation parameter `param_no` and renumbers every later
// parameter down by one so the numbers stay dense. The caller is responsible
// for keeping the fusion instruction's operands consistent.
Status HloComputation::RemoveParameter(int64 param_no) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK(IsFusionComputation());
  HloInstruction* param_instruction = param_instructions_[param_no];
  auto param_instruction_iterator = param_instructions_.begin() + param_no;
  param_instructions_.erase(param_instruction_iterator);
  // Throw removed fused parameter instruction away.
  TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));

  // Parameter numbers are baked into the instructions, so each subsequent
  // parameter is replaced by a freshly-created one with the shifted number.
  while (param_no < param_instructions_.size()) {
    param_instruction = param_instructions_[param_no];
    HloInstruction* new_instr =
        AddInstructionInternal(HloInstruction::CreateParameter(
            param_no, param_instruction->shape(), StrCat("param_", param_no)));
    TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
    param_instructions_[param_no] = new_instr;
    TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
    param_no++;
  }

  return Status::OK();
}
188 
// Drops unused parameters; restricted to fusion computations (the impl
// CHECK-fails on a non-fusion computation when allow_non_fusion is false).
Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
  return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
}
192 
// Drops unused parameters from any computation, fusion or not.
Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
  return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
}
196 
// Removes every parameter that has no users (and is not the root), compacting
// the remaining parameter numbers so they stay dense starting at 0.
Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
  CHECK(allow_non_fusion || IsFusionComputation());
  int64 removed = 0;
  for (int64 i = 0; i < param_instructions_.size(); ++i) {
    HloInstruction* param_instruction = param_instructions_[i];
    if (param_instruction->user_count() == 0 &&
        param_instruction != root_instruction()) {
      // Unused parameter: delete it and remember the shift for later slots.
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
      ++removed;
      continue;
    }

    if (removed > 0) {
      // A parameter earlier in the list was removed, so this (kept) parameter
      // must move down by `removed` slots; parameter numbers are immutable,
      // so recreate it with the new number and redirect all uses.
      const int64 param_no = i - removed;
      HloInstruction* new_instr = AddInstructionInternal(
          HloInstruction::CreateParameter(param_no, param_instruction->shape(),
                                          StrCat("param_", param_no)));
      TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
      param_instructions_[param_no] = new_instr;
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
    }
  }
  // Trailing slots now hold stale pointers; shrink the table.
  param_instructions_.resize(param_instructions_.size() - removed);
  return Status::OK();
}
224 
IsSafelyRemovable(const HloInstruction * instruction)225 bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
226   // If the instruction has control predecessors or successors then we cannot
227   // remove the instruction without violating ordering constraints (added, for
228   // example, to avert interference due to buffer aliasing).
229   if (!instruction->control_predecessors().empty() ||
230       !instruction->control_successors().empty()) {
231     return false;
232   }
233 
234   if (instruction->opcode() == HloOpcode::kParameter &&
235       !IsFusionComputation()) {
236     return false;
237   }
238 
239   return true;
240 }
241 
HasSideEffect() const242 bool HloComputation::HasSideEffect() const {
243   for (auto* instruction : instructions()) {
244     if (instruction->HasSideEffect()) {
245       return true;
246     }
247   }
248   return false;
249 }
250 
// Removes `instruction` and, transitively, any operand that becomes dead as a
// result. `cleanup`, if provided, is invoked on each instruction just before
// it is removed. The starting instruction must already be user-free, safely
// removable, and not the root.
Status HloComputation::RemoveInstructionAndUnusedOperands(
    HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
  TF_RET_CHECK(root_instruction() != instruction);

  TF_RET_CHECK(instruction->user_count() == 0);
  TF_RET_CHECK(IsSafelyRemovable(instruction))
      << "Cannot remove instruction: " << instruction->ToString();
  absl::flat_hash_set<HloInstruction*> removed;
  std::queue<HloInstruction*> worklist;
  worklist.push(instruction);
  while (!worklist.empty()) {
    HloInstruction* item = worklist.front();
    worklist.pop();

    // Skip anything already removed, still used, structurally required, or
    // side-effecting (a side effect only blocks removal of *operands*, never
    // the explicitly requested starting instruction).
    if (removed.contains(item) || item->user_count() != 0 ||
        item == root_instruction() || !IsSafelyRemovable(item) ||
        (item->HasSideEffect() && item != instruction)) {
      continue;
    }
    // Operands may become dead once `item` is gone; revisit them.
    for (int i = 0; i < item->operand_count(); ++i) {
      worklist.push(item->mutable_operand(i));
    }

    if (cleanup) {
      cleanup(item);
    }
    TF_RETURN_IF_ERROR(RemoveInstruction(item));
    removed.insert(item);
  }
  return Status::OK();
}
282 
// Removes `instruction` with the safety check enabled (see IsSafelyRemovable).
// NOTE(review): the second argument of RemoveInstructionImpl is named
// `ignore_safety_check` there — the comment here matches that meaning.
Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
  return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
}
286 
// Removes `instruction` even when IsSafelyRemovable would say no (e.g. a
// parameter of a non-fusion computation); other preconditions still apply.
Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
  return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
}
290 
// Core removal: validates that `instruction` is unreferenced (no users, no
// control edges, not the root), then erases it from the instruction list and
// the iterator map. The unique_ptr in instructions_ owns it, so erasing the
// list entry destroys the instruction.
Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
                                             bool ignore_safety_check) {
  VLOG(2) << "Removing instruction " << instruction->name()
          << " from computation " << name();
  TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
      << "cannot remove instruction: " << instruction->ToString();
  TF_RET_CHECK(root_instruction() != instruction)
      << "cannot remove root instruction " << instruction->name();
  TF_RET_CHECK(instruction->user_count() == 0)
      << "instruction " << instruction->name()
      << " has users and cannot be removed";
  TF_RET_CHECK(instruction->control_predecessors().empty())
      << "instruction " << instruction->name()
      << " has control predecessors and cannot be removed";
  TF_RET_CHECK(instruction->control_successors().empty())
      << "instruction " << instruction->name()
      << " has control successors and cannot be removed";

  auto inst_it = instruction_iterators_.find(instruction);
  TF_RET_CHECK(inst_it != instruction_iterators_.end());
  // Detach from this computation before destruction.
  (*inst_it->second)->set_parent(nullptr);
  instructions_.erase(inst_it->second);
  instruction_iterators_.erase(inst_it);
  return Status::OK();
}
316 
// Makes `new_root_instruction` the root. Outside fusion computations the new
// root's shape must be compatible with the old one unless the caller opts
// out; changing the entry computation's root shape also resets the module's
// input/output alias config.
void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
                                          bool accept_different_shape) {
  // The shape of the root (ignoring layout) is an invariant of the computation
  // for non-fusion cases.
  if (!IsFusionComputation() && !accept_different_shape) {
    CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
                                root_instruction_->shape()))
        << new_root_instruction->shape() << " is incompatible with "
        << root_instruction_->shape();
  }
  // The new root must already be an instruction of this computation
  // (debug-only check: linear scan over all instructions).
  bool root_found = false;
  for (auto& instruction : instructions_) {
    if (new_root_instruction == instruction.get()) {
      root_found = true;
      break;
    }
  }
  DCHECK(root_found);

  if (parent() && parent()->has_entry_computation() &&
      parent()->entry_computation() == this) {
    if (!Shape::Equal()(new_root_instruction->shape(),
                        root_instruction_->shape())) {
      // Rebuild input output alias config now that we have a new output shape.
      parent()->input_output_alias_config() =
          HloInputOutputAliasConfig(new_root_instruction->shape());
    }
  }

  root_instruction_ = new_root_instruction;
}
348 
349 namespace {
350 
351 // Helper which builds a post order of the HLO call graph.
ComputeComputationPostOrder(HloComputation * computation,absl::flat_hash_set<HloComputation * > * visited,std::vector<HloComputation * > * post_order)352 void ComputeComputationPostOrder(HloComputation* computation,
353                                  absl::flat_hash_set<HloComputation*>* visited,
354                                  std::vector<HloComputation*>* post_order) {
355   if (visited->insert(computation).second) {
356     for (auto* instruction : computation->instructions()) {
357       for (HloComputation* called_computation :
358            instruction->called_computations()) {
359         ComputeComputationPostOrder(called_computation, visited, post_order);
360       }
361     }
362     post_order->push_back(computation);
363   }
364 }
365 
366 }  // namespace
367 
// Iterative DFS that appends instructions reachable from `root` to
// `post_order` in operand-before-user order. Instructions that share a
// channel id (per `channel_dependency_group`) are treated as a unit: visiting
// one pulls in the predecessors of all of them, so the whole group ends up
// contiguous and correctly ordered relative to its dependencies.
void HloComputation::ComputeInstructionPostOrder(
    const HloComputation::ChannelDependencyGroup& channel_dependency_group,
    std::vector<HloInstruction*>* post_order, HloInstruction* root,
    absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
  std::vector<HloInstruction*> dfs_stack;
  dfs_stack.push_back(root);
  while (!dfs_stack.empty()) {
    const auto current = dfs_stack.back();
    auto it = visited->find(current);
    if (it != visited->end()) {
      if (it->second == kVisited) {
        // Already visited.
        dfs_stack.pop_back();
        continue;
      }
      // Visit this node: second encounter means all predecessors are done.
      CHECK_EQ(kVisiting, it->second);
      dfs_stack.pop_back();
      post_order->push_back(current);
      it->second = kVisited;
      continue;
    }

    // First encounter: mark in-progress; the node stays on the stack so it is
    // emitted after its predecessors (pushed below) have been processed.
    visited->insert({current, kVisiting});

    // Channel id for instructions participating in a channel "group"
    // (kRecvDone and kAllReduce only; kSend is grouped via the map keys).
    const auto get_channel_id =
        [](HloInstruction* inst) -> absl::optional<int64> {
      switch (inst->opcode()) {
        case HloOpcode::kRecvDone:
          return inst->channel_id();
        case HloOpcode::kAllReduce:
          return inst->channel_id();
        default:
          return absl::nullopt;
      }
    };

    // When adding a predecessor to the dfs_stack, we need to also add its
    // associated channel dependencies.
    const auto add_dfs_stack = [&](HloInstruction* inst) {
      auto channel_id = get_channel_id(inst);
      if (channel_id && channel_dependency_group.count(*channel_id)) {
        auto it = channel_dependency_group.find(*channel_id);
        for (HloInstruction* cinst : it->second) {
          dfs_stack.emplace_back(cinst);
        }
      } else {
        dfs_stack.emplace_back(inst);
      }
    };

    const auto add_predecessors = [&](HloInstruction* inst) {
      // Add the operands to the stack in reverse order so the first operand is
      // processed first. This will produce a more natural ordering and a nicer
      // result for things like HLO stringification.
      const auto& operands = inst->operands();
      for (int64 i = operands.size() - 1; i >= 0; --i) {
        add_dfs_stack(operands[i]);
      }

      for (HloInstruction* op : inst->control_predecessors()) {
        add_dfs_stack(op);
      }
    };

    // If the current instruction is a channel instruction, add the dependencies
    // from all associated instructions of the channel.
    auto channel_id = get_channel_id(current);
    if (channel_id && channel_dependency_group.count(*channel_id)) {
      auto it = channel_dependency_group.find(*channel_id);
      for (HloInstruction* cinst : it->second) {
        add_predecessors(cinst);
      }
    } else {
      add_predecessors(current);
    }
  }
}
446 
447 HloComputation::ChannelDependencyGroup
ComputeChannelDependencies() const448 HloComputation::ComputeChannelDependencies() const {
449   ChannelDependencyGroup channel_dependency_group;
450   for (const auto& instruction : instructions_) {
451     switch (instruction->opcode()) {
452       case HloOpcode::kSend:
453       case HloOpcode::kRecvDone:
454       case HloOpcode::kAllReduce: {
455         auto channel_id = instruction->channel_id();
456         if (channel_id) {
457           channel_dependency_group[channel_id.value()].push_back(
458               instruction.get());
459         }
460         break;
461       }
462       default:
463         break;
464     }
465   }
466   return channel_dependency_group;
467 }
468 
HasOnlyTraceUsers(const HloInstruction * instruction)469 static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
470   return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
471     return user->opcode() == HloOpcode::kTrace;
472   });
473 }
474 
// Returns all instructions in operand-before-user (post) order. DFS is
// started from every instruction whose users are all traces — this covers
// the root as well as otherwise-dead subgraphs — and the shared `visited`
// map keeps the runs disjoint. Trace instructions themselves are appended at
// the end.
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
  auto channel_dependency_group = ComputeChannelDependencies();
  std::vector<HloInstruction*> post_order;
  post_order.reserve(instruction_count());
  std::vector<HloInstruction*> trace_instructions;
  absl::flat_hash_map<HloInstruction*, VisitState> visited;
  visited.reserve(instruction_count());
  for (auto& instruction : instructions_) {
    if (instruction->opcode() == HloOpcode::kTrace) {
      // Trace instructions aren't handled by the DFS visitor. Add trace
      // instructions to the post order at the end (necessarily they have no
      // users).
      trace_instructions.push_back(instruction.get());
    } else if (HasOnlyTraceUsers(instruction.get())) {
      ComputeInstructionPostOrder(channel_dependency_group, &post_order,
                                  instruction.get(), &visited);
    }
  }
  post_order.insert(post_order.end(), trace_instructions.begin(),
                    trace_instructions.end());
  // Every instruction must appear exactly once.
  CHECK_EQ(instructions_.size(), post_order.size())
      << "number of instructions does not match post order size";
  return post_order;
}
499 
// Returns all computations transitively called from this one, in post order
// (callees before callers). This computation itself is excluded.
std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
    const {
  absl::flat_hash_set<HloComputation*> visited;
  std::vector<HloComputation*> post_order;

  // To avoid special handling of this computation, cast away const of
  // 'this'. 'this' is immediately removed from the post order after
  // construction.
  //
  // TODO(b/78350259): This violates const-correctness, since while the original
  // computation is not returned, we still retrieve non-const computations from
  // a const one. Consider also avoiding const for HloComputation, or review XLA
  // for const-correctness of non-HloInstruction* types like this.
  ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
                              &post_order);

  // We don't want to include this computation in the post order.
  // (Post order guarantees 'this' is the last element.)
  CHECK_EQ(this, post_order.back());
  post_order.pop_back();

  return post_order;
}
522 
// Convenience overload: prints using the canonical instruction post order.
string HloComputation::ToString(const HloPrintOptions& options) const {
  return ToString(options, MakeInstructionPostOrder());
}
526 
ToString(const HloPrintOptions & options,absl::Span<const HloInstruction * const> instruction_order) const527 string HloComputation::ToString(
528     const HloPrintOptions& options,
529     absl::Span<const HloInstruction* const> instruction_order) const {
530   CHECK_EQ(instruction_order.size(), instruction_count());
531 
532   const string tab(2 * options.indent_amount(), ' ');
533 
534   std::ostringstream s;
535   s << tab;
536 
537   if (!options.is_in_nested_computation()) {
538     if (options.print_percent()) {
539       s << "%";
540     }
541     s << PrintName(name(), options.print_ids()) << " ";
542   }
543 
544   if (options.print_program_shape()) {
545     s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
546       << " ";
547   }
548   s << "{\n";
549 
550   // There are instructions which are required to be printed. Additionally, we
551   // print some instructions before and after required ones. The resulting
552   // output has the following format.
553   //
554   //  computation {
555   //    ...
556   //    additional_instructions
557   //    required_instructions
558   //    additional_instructions
559   //    ...
560   //    additional_instructions
561   //    required_instructions
562   //    additional_instructions
563   //    ...
564   //  }
565   std::set<int> instructions_to_print;
566   {
567     // Find all the instructions that should be printed.
568     auto add_instruction = [&instructions_to_print,
569                             &instruction_order](int index) {
570       if (index < 0 || index >= instruction_order.size()) {
571         return;
572       }
573       instructions_to_print.insert(index);
574     };
575 
576     auto add_instructions_arround = [&add_instruction, &options](int index) {
577       for (int i = index - options.leading_and_trailing_instructions_number();
578            i <= index + options.leading_and_trailing_instructions_number();
579            ++i) {
580         add_instruction(i);
581       }
582     };
583 
584     for (int i = 0; i < instruction_order.size(); ++i) {
585       const HloInstruction* instruction = instruction_order[i];
586       CHECK_EQ(this, instruction->parent());
587       if (options.print_instruction(instruction)) {
588         add_instructions_arround(i);
589       }
590     }
591   }
592 
593   {
594     // Print the instructions in this computation.
595     HloPrintOptions new_options = options;
596     new_options.set_indent_amount(options.indent_amount() + 1)
597         .set_is_in_nested_computation(true);
598 
599     const string new_tab(2 * new_options.indent_amount(), ' ');
600 
601     CanonicalNameMap name_map;
602 
603     bool print_prev = true;
604     for (int index = 0; index < instruction_order.size(); ++index) {
605       const HloInstruction* instruction = instruction_order[index];
606       if (instructions_to_print.find(index) != instructions_to_print.end()) {
607         s << new_options.format_instruction(
608                  instruction,
609                  instruction->ToStringWithCanonicalNameMap(new_options,
610                                                            &name_map),
611                  new_options.indent_amount(), instruction == root_instruction_)
612           << "\n";
613         print_prev = true;
614       } else if (print_prev) {
615         s << new_tab << "...\n";
616         print_prev = false;
617       }
618     }
619   }
620 
621   s << tab << "}";
622   return s.str();
623 }
624 
ToProto() const625 HloComputationProto HloComputation::ToProto() const {
626   HloComputationProto proto;
627   CHECK(unique_id_ != -1)
628       << "This computation does not have a valid id. Please make sure the "
629          "computation is inside a module before dumping it.";
630   proto.set_id(unique_id_);
631   proto.set_name(name_);
632   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
633     HloInstructionProto instruction_proto = instruction->ToProto();
634     proto.add_instructions()->Swap(&instruction_proto);
635   }
636   proto.set_root_id(root_instruction()->unique_id());
637   *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
638   return proto;
639 }
640 
// Deserializes an HloComputation from `proto`. `computation_map` supplies the
// already-deserialized computations this one may call. Instructions are
// recreated, re-sorted by their proto ids, and validated (root present,
// parameter numbers dense in [0, parameter_count)).
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
    const HloComputationProto& proto,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map,
    bool prohibit_empty_literal) {
  absl::flat_hash_map<int64, HloInstruction*> instruction_map;
  absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  int64 parameter_count = 0;
  for (const HloInstructionProto& instruction_proto : proto.instructions()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
                        HloInstruction::CreateFromProto(
                            instruction_proto, instruction_map, computation_map,
                            prohibit_empty_literal));
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }
    // Instruction ids must be unique within a computation proto.
    TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
    instruction_map[instruction_proto.id()] = instruction.get();
    to_proto_id[instruction.get()] = instruction_proto.id();
    instructions.push_back(std::move(instruction));
  }

  TF_RET_CHECK(proto.root_id() != -1);
  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
  HloInstruction* root = instruction_map.at(proto.root_id());

  // Sort the instructions in the proto id's order.
  absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
                                 const std::unique_ptr<HloInstruction>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Validate that parameter numbers are exactly 0..parameter_count-1 with no
  // duplicates or gaps.
  TF_RETURN_IF_ERROR([&]() -> Status {
    std::vector<bool> parameters_seen(parameter_count);
    int parameters_seen_count = 0;
    for (auto& instruction : instructions) {
      if (instruction->opcode() == HloOpcode::kParameter) {
        int64 param_no = instruction->parameter_number();
        TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
            << "Invalid parameter number.  Expected [0, " << parameter_count
            << "), got " << param_no;
        TF_RET_CHECK(!parameters_seen[param_no])
            << "Parameter number " << param_no
            << " already allocated in this computation";
        parameters_seen[param_no] = true;
        parameters_seen_count++;
      }
    }
    TF_RET_CHECK(parameters_seen_count == parameter_count)
        << "Not all parameters in range [0, " << parameter_count
        << ") were referenced";
    return Status::OK();
  }());

  auto computation = absl::WrapUnique(
      new HloComputation(proto.name(), parameter_count, &instructions, root,
                         /*fusion_instruction=*/nullptr));
  computation->unique_id_ = proto.id();
  return std::move(computation);
}
702 
// Moves `instructions_to_fuse` (first element is the fused root) into
// `fusion_instruction`. All uses of the old root are redirected to the fusion
// instruction; instructions that become dead after fusing are removed from
// this computation.
void HloComputation::FuseInstructionsInto(
    absl::Span<HloInstruction* const> instructions_to_fuse,
    HloInstruction* fusion_instruction) {
  CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
  HloInstruction* root = instructions_to_fuse.front();
  TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
  if (root == root_instruction()) {
    set_root_instruction(fusion_instruction);
  }
  // The old root is now unused (all uses were redirected above).
  TF_CHECK_OK(RemoveInstruction(root));
  for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
    HloInstruction* instruction = instructions_to_fuse[i];
    fusion_instruction->FuseInstruction(instruction);
    // Keep instructions that still have users outside the fusion.
    if (instruction->user_count() == 0) {
      TF_CHECK_OK(RemoveInstruction(instruction));
    }
  }
}
721 
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)722 HloInstruction* HloComputation::CreateFusionInstruction(
723     absl::Span<HloInstruction* const> instructions_to_fuse,
724     HloInstruction::FusionKind fusion_kind) {
725   HloInstruction* root = instructions_to_fuse.front();
726   HloInstruction* fusion_instruction = AddInstruction(
727       HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
728   FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
729   return fusion_instruction;
730 }
731 
DeepCopyHelper(HloInstruction * instruction,ShapeIndex * index,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)732 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
733     HloInstruction* instruction, ShapeIndex* index,
734     const std::function<
735         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
736                         HloComputation* computation)>& copy_leaf) {
737   if (instruction->shape().IsTuple()) {
738     std::vector<HloInstruction*> elements;
739     for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
740          i++) {
741       HloInstruction* gte =
742           AddInstruction(HloInstruction::CreateGetTupleElement(
743               ShapeUtil::GetTupleElementShape(instruction->shape(), i),
744               instruction, i));
745 
746       index->push_back(i);
747       TF_ASSIGN_OR_RETURN(HloInstruction * element,
748                           DeepCopyHelper(gte, index, copy_leaf));
749       elements.push_back(element);
750       index->pop_back();
751     }
752     return AddInstruction(HloInstruction::CreateTuple(elements));
753   }
754   if (instruction->shape().IsToken()) {
755     // Tokens have no on-device representation and cannot be copied. Pass
756     // through transparently.
757     return instruction;
758   }
759 
760   // Array shape.
761   TF_RET_CHECK(instruction->shape().IsArray());
762   return copy_leaf(instruction, *index, this);
763 }
764 
DeepCopyInstruction(HloInstruction * instruction,const ShapeTree<bool> * indices_to_copy,ShapeTree<HloInstruction * > * copies_added)765 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
766     HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
767     ShapeTree<HloInstruction*>* copies_added) {
768   if (instruction->parent() != this) {
769     return FailedPrecondition(
770         "Can't deep copy instruction %s: instruction is not in computation %s",
771         instruction->name(), name());
772   }
773   if (indices_to_copy != nullptr &&
774       !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
775     return FailedPrecondition(
776         "Can't deep copy instruction %s: given shape tree of indices to copy "
777         "has incompatible shapes: %s vs. %s",
778         instruction->name(), ShapeUtil::HumanString(instruction->shape()),
779         ShapeUtil::HumanString(indices_to_copy->shape()));
780   }
781 
782   ShapeIndex index;
783   auto copy_leaf = [indices_to_copy, copies_added](
784                        HloInstruction* leaf, const ShapeIndex& leaf_index,
785                        HloComputation* computation) {
786     if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
787       HloInstruction* copy = computation->AddInstruction(
788           HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
789       if (copies_added != nullptr) {
790         *copies_added->mutable_element(leaf_index) = copy;
791       }
792       return copy;
793     }
794     // Elements which are not to be copied are passed through
795     // transparently.
796     return leaf;
797   };
798   return DeepCopyHelper(instruction, &index, copy_leaf);
799 }
800 
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)801 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
802     HloInstruction* instruction,
803     const std::function<
804         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
805                         HloComputation* computation)>& copy_leaf) {
806   if (instruction->parent() != this) {
807     return FailedPrecondition(
808         "Can't deep copy instruction %s: instruction is not in computation %s",
809         instruction->name(), name());
810   }
811   ShapeIndex index;
812   return DeepCopyHelper(instruction, &index, copy_leaf);
813 }
814 
ComputeProgramShape(bool include_ids) const815 ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
816   ProgramShape program_shape;
817 
818   for (auto* param_instruction : param_instructions_) {
819     *program_shape.add_parameters() = param_instruction->shape();
820     *program_shape.add_parameter_names() =
821         PrintName(param_instruction->name(), include_ids);
822   }
823   *program_shape.mutable_result() = root_instruction_->shape();
824 
825   return program_shape;
826 }
827 
Equal(const HloComputation & other,bool is_layout_sensitive) const828 bool HloComputation::Equal(const HloComputation& other,
829                            bool is_layout_sensitive) const {
830   if (this == &other) {
831     return true;
832   }
833   absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
834       visited;
835   std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
836 
837   worklist.push_back({root_instruction(), other.root_instruction()});
838 
839   while (!worklist.empty()) {
840     auto pair = worklist.back();
841     worklist.pop_back();
842 
843     if (visited.contains(pair)) {
844       continue;
845     }
846     visited.emplace(pair);
847     // TODO(b/123082518): Avoid recursively invoking == because it may
848     // cause a stack overflow with deeply nested subcomputations.
849     bool identical_ignoring_operands = pair.first->Identical(
850         *pair.second,
851         [](const HloInstruction*, const HloInstruction*) { return true; },
852         [](const HloComputation* a, const HloComputation* b) {
853           return *a == *b;
854         },
855         is_layout_sensitive);
856     if (!identical_ignoring_operands) {
857       return false;
858     }
859     for (size_t i = 0; i < pair.first->operands().size(); ++i) {
860       worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
861     }
862   }
863   return true;
864 }
865 
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)866 Status HloComputation::ReplaceWithNewInstruction(
867     HloInstruction* old_instruction,
868     std::unique_ptr<HloInstruction> new_instruction) {
869   return ReplaceInstruction(old_instruction,
870                             AddInstruction(std::move(new_instruction)));
871 }
872 
ReplaceWithNewEntryComputationParameter(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)873 Status HloComputation::ReplaceWithNewEntryComputationParameter(
874     HloInstruction* old_instruction,
875     std::unique_ptr<HloInstruction> new_instruction) {
876   return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
877                                                  std::move(new_instruction)));
878 }
879 
// Replaces `old_instruction` with `new_instruction`: rewires all users to the
// new instruction, inherits metadata/frontend-attributes/sharding where the
// new instruction left them unset, and removes the old instruction together
// with operands that become unused. The shapes must be compatible.
// `new_instruction` is expected to already be in this computation (callers
// such as ReplaceWithNewInstruction add it first).
Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                          HloInstruction* new_instruction) {
  TF_RET_CHECK(
      ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
      << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
      << ShapeUtil::HumanString(new_instruction->shape());

  VLOG(10) << "transformed " << old_instruction->ToString() << " to "
           << new_instruction->ToString();
  // Try to add metadata for HLO instructions that are created to replace
  // existing HLO instructions (e.g. during optimizations). The assumption is
  // that the old instruction and the new instruction would perform the same
  // function, and that they would be correlated to the same TF op. This might
  // not always be correct since HLO optimizations can cross TF op boundaries.
  // But still this seems to be better than nothing.
  if (new_instruction->metadata().op_name().empty()) {
    new_instruction->set_metadata(old_instruction->metadata());
  }
  // Frontend attributes are inherited under the same assumption: only when
  // the new instruction carries none of its own.
  if (new_instruction->frontend_attributes().map().empty()) {
    new_instruction->set_frontend_attributes(
        old_instruction->frontend_attributes());
  }

  // Like the metadata above, if the user didn't specify any sharding
  // information on the new instruction we should copy the old sharding
  // information (if any).
  if (!new_instruction->has_sharding()) {
    new_instruction->set_sharding(old_instruction->sharding_ptr());
  }

  TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
  // The old instruction now has no users; removing it may also remove
  // operands that became dead in the process.
  return RemoveInstructionAndUnusedOperands(old_instruction);
}
913 
CollectUnreachableRoots() const914 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
915   std::vector<HloInstruction*> unreachable_roots;
916   for (auto* instruction : instructions()) {
917     if (instruction->user_count() == 0 &&
918         instruction->control_successors().empty() &&
919         instruction != root_instruction()) {
920       unreachable_roots.push_back(instruction);
921     }
922   }
923   VLOG(3) << "Unreachable roots:"
924           << absl::StrJoin(unreachable_roots, "\n\t",
925                            [](string* out, const HloInstruction* hlo) {
926                              absl::StrAppend(out, hlo->ToString());
927                            });
928   return unreachable_roots;
929 }
930 
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const931 Status HloComputation::AcceptWithOperandOrder(
932     DfsHloVisitor* visitor,
933     const HloInstruction::CompareFunction& operand_order) const {
934   // Visit unreachable roots. Beware that the visitor might delete the currently
935   // visited root, which would invalidate iterators if the unreachable roots
936   // weren't computed ahead of time.
937   for (HloInstruction* root : CollectUnreachableRoots()) {
938     TF_RETURN_IF_ERROR(
939         root->AcceptWithOperandOrder(visitor, operand_order,
940                                      /*call_finish_visit=*/false));
941   }
942   // Visit the computation root instruction last.
943   return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
944                                                     /*call_finish_visit=*/true);
945 }
946 
Clone(const string & suffix,HloCloneContext * context)947 std::unique_ptr<HloComputation> HloComputation::Clone(
948     const string& suffix, HloCloneContext* context) {
949   return CloneWithReplacements(
950       /*replacements=*/absl::flat_hash_map<const HloInstruction*,
951                                            std::unique_ptr<HloInstruction>>(),
952       /*extra_parameters=*/{}, context, suffix);
953 }
954 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)955 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
956     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
957     HloCloneContext* context, const string& suffix) {
958   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
959       replacements;
960   replacements.emplace(std::move(r1));
961   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
962                                context, suffix);
963 }
964 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)965 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
966     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
967     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
968     HloCloneContext* context, const string& suffix) {
969   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
970       replacements;
971   replacements.emplace(std::move(r1));
972   replacements.emplace(std::move(r2));
973   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
974                                context, suffix);
975 }
976 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)977 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
978     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
979     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
980     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
981     HloCloneContext* context, const string& suffix) {
982   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
983       replacements;
984   replacements.emplace(std::move(r1));
985   replacements.emplace(std::move(r2));
986   replacements.emplace(std::move(r3));
987   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
988                                context, suffix);
989 }
990 
// Clones this computation into a new one named "<name>.<suffix>", applying
// `replacements`: each key instruction is replaced by its mapped value, and a
// null mapped value removes the instruction from the clone entirely.
// `extra_parameters` (which must all be kParameter) are cloned in as
// additional parameters. If `context` is null, a context local to this call
// is created; otherwise the caller's context records all clone mappings.
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const string& suffix) {
  std::unique_ptr<HloCloneContext> context_ptr;
  if (context == nullptr) {
    context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
    context = context_ptr.get();
  }

  // Look up instr in the replacements map, and return either the replacement,
  // or instr, if the replacement isn't present.
  //
  // Note: This can return null, indicating that instr should not be present in
  // the new computation.
  auto replace = [&](HloInstruction* instr) {
    auto it = replacements.find(instr);
    if (it == replacements.end()) {
      return instr;
    }
    return it->second.get();
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We want to do a postorder walk over [replace(i) for i in instructions_].
  // We can't reuse MakeInstructionPostOrder() for this, because that will
  // generate a postorder of plain instructions_, and our replacements may
  // change the postorder!
  //
  // The postorder we want here is simpler than what MakeInstructionPostOrder()
  // does -- we only care about operand dependencies -- so let's just do it
  // ourselves.
  std::vector<HloInstruction*> postorder;
  absl::flat_hash_map<HloInstruction*, VisitState> visited;
  for (const auto& instr : instructions_) {
    std::vector<HloInstruction*> dfs_stack;
    HloInstruction* new_instr = replace(instr.get());
    if (!new_instr) {
      // Mapped to null: this instruction is dropped from the clone.
      continue;
    }
    dfs_stack.push_back(new_instr);

    // Iterative DFS: an instruction is marked kVisiting when first seen and
    // appended to `postorder` (becoming kVisited) the second time it is
    // popped, i.e. after all of its operands have been emitted.
    while (!dfs_stack.empty()) {
      auto* cur = dfs_stack.back();
      auto it = visited.find(cur);
      if (it != visited.end()) {
        dfs_stack.pop_back();
        if (it->second == kVisited) {
          continue;
        }
        CHECK_EQ(it->second, kVisiting);
        postorder.push_back(cur);
        it->second = kVisited;
        continue;
      }

      visited.insert({cur, kVisiting});
      for (HloInstruction* operand : cur->operands()) {
        HloInstruction* new_operand = replace(operand);
        if (new_operand) {
          dfs_stack.emplace_back(new_operand);
        }
      }
    }
  }

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // First add the extra parameters to 'instructions'.
  for (const auto& instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  // Clone in postorder so every operand is cloned (and registered in
  // `context`) before any instruction that uses it.
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      // A used instruction must not have been mapped to null.
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(context->GetInstruction(replaced_operand));
    }
    std::unique_ptr<HloInstruction> new_instr =
        instr->CloneWithNewOperands(instr->shape(), new_operands, context);
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parameter_replicated_at_leaf_buffers().has_value()) {
      // Preserve the replicated-at-leaf-buffers annotation on parameters,
      // which CloneWithNewOperands does not carry over by itself here.
      new_instr->set_parameter_replicated_at_leaf_buffers(
          instr->parameter_replicated_at_leaf_buffers().value());
    }
    instructions.push_back(std::move(new_instr));
  }
  Builder builder(name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(
          replace(root_instruction())));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = context->GetInstruction(instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);
      // successor may not have been remapped, because it might have been
      // removed by the replacements map.
      if (replaced_successor != nullptr) {
        TF_CHECK_OK(new_instr->AddControlDependencyTo(
            context->GetInstruction(replaced_successor)));
      }
    }
  }
  context->MapComputation(this, result.get());
  return result;
}
1108 
// Renames this computation to a name that `name_uniquer` guarantees is unique
// among the names it has handed out.
void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  name_ = name_uniquer->GetUniqueName(name_);
}
1112 
GetInstructionWithName(absl::string_view name)1113 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1114   auto instructions_in_computation = instructions();
1115   auto it = absl::c_find_if(
1116       instructions_in_computation,
1117       [&](HloInstruction* instr) { return instr->name() == name; });
1118   return it == instructions_in_computation.end() ? nullptr : *it;
1119 }
1120 
// True iff this computation is the entry computation of its parent module.
bool HloComputation::IsEntryComputation() const {
  return parent()->entry_computation() == this;
}
1124 }  // namespace xla
1125