1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <cstdint>
21 #include <functional>
22 #include <list>
23 #include <memory>
24 #include <optional>
25 #include <queue>
26 #include <set>
27 #include <sstream>
28 #include <string>
29 #include <utility>
30 #include <vector>
31 
32 #include "absl/algorithm/container.h"
33 #include "absl/container/flat_hash_map.h"
34 #include "absl/container/flat_hash_set.h"
35 #include "absl/strings/numbers.h"
36 #include "absl/strings/str_cat.h"
37 #include "absl/strings/str_join.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/map_util.h"
40 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
41 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_module.h"
45 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
46 #include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
47 #include "tensorflow/compiler/xla/shape_util.h"
48 #include "tensorflow/compiler/xla/status_macros.h"
49 #include "tensorflow/compiler/xla/types.h"
50 #include "tensorflow/compiler/xla/util.h"
51 #include "tensorflow/core/lib/core/errors.h"
52 #include "tensorflow/core/lib/core/status.h"
53 #include "tensorflow/core/platform/logging.h"
54 
55 namespace xla {
56 
57 using absl::StrCat;
58 
59 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
60     HloInstruction* root_instruction) {
61   int parameter_count = 0;
62   for (auto& instruction : instructions_) {
63     if (instruction->opcode() == HloOpcode::kParameter) {
64       parameter_count++;
65     }
66   }
67   // If root_instruction is not specified, use the last added instruction.
68   HloInstruction* root =
69       root_instruction ? root_instruction : last_added_instruction_;
70   CHECK_NE(nullptr, root);
71   return absl::WrapUnique(new HloComputation(
72       name_, parameter_count, &instructions_, root, fusion_instruction_));
73 }
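// A minimal usage sketch of Builder (illustrative only; assumes an existing
// HloModule* `module` and a Shape `shape` defined elsewhere):
//
//   HloComputation::Builder builder("example");
//   HloInstruction* param = builder.AddInstruction(
//       HloInstruction::CreateParameter(0, shape, "param_0"));
//   HloComputation* computation =
//       module->AddEmbeddedComputation(builder.Build(/*root_instruction=*/param));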
74 
75 HloComputation::HloComputation(
76     const std::string& name, int parameter_count,
77     std::vector<std::unique_ptr<HloInstruction>>* instructions,
78     HloInstruction* root_instruction, HloInstruction* fusion_instruction)
79     : name_(NameUniquer::GetSanitizedName(name)),
80       unique_id_(-1),
81       root_instruction_(root_instruction),
82       fusion_instruction_(fusion_instruction),
83       is_fusion_computation_(fusion_instruction != nullptr),
84       custom_call_instruction_(nullptr),
85       is_custom_call_computation_(false) {
86   param_instructions_.resize(parameter_count, nullptr);
87   bool root_found = false;
88   for (auto& instruction : *instructions) {
89     if (instruction->opcode() == HloOpcode::kParameter) {
90       int64_t param_no = instruction->parameter_number();
91       CHECK(param_no >= 0 && param_no < parameter_count)
92           << "\nERROR: invalid parameter number.  Expected [0, "
93           << parameter_count << "), got " << param_no;
94       CHECK(param_instructions_[param_no] == nullptr)
95           << "\nERROR: parameter number " << param_no
96           << " already allocated in this computation";
97       param_instructions_[param_no] = instruction.get();
98     }
99     root_found |= instruction.get() == root_instruction_;
100     AddInstructionInternal(std::move(instruction));
101   }
102   CHECK(root_found)
103       << "\nERROR: root instruction is not present in computation.";
104 }
105 
106 HloComputation::~HloComputation() {
107   if (fusion_instruction_ != nullptr) {
108     CHECK(fusion_instruction_->fused_instructions_computation() == this);
109     fusion_instruction_->ClearCalledComputations();
110     fusion_instruction_ = nullptr;
111   }
112   if (IsAsyncComputation()) {
113     for (auto* async_instr : async_instructions_) {
114       CHECK(async_instr->async_wrapped_computation() == this);
115       async_instr->ClearCalledComputations();
116     }
117     async_instructions_.clear();
118   }
119 }
120 
121 HloInstruction* HloComputation::AddInstruction(
122     std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
123   CHECK(instruction->opcode() != HloOpcode::kParameter)
124       << "Parameter instructions cannot be added to a computation after "
125       << "it has been built";
126   if (!new_name.empty()) {
127     instruction->SetAndSanitizeName(new_name);
128   }
129   return AddInstructionInternal(std::move(instruction));
130 }
131 
132 HloInstruction* HloComputation::AddInstruction(
133     std::unique_ptr<HloInstruction> instruction, const OpMetadata* metadata) {
134   if (metadata != nullptr) {
135     instruction->set_metadata(*metadata);
136   }
137   return AddInstruction(std::move(instruction));
138 }
139 
140 HloInstruction* HloComputation::AddInstructionInternal(
141     std::unique_ptr<HloInstruction> instruction) {
142   if (parent() != nullptr) {
143     instruction->UniquifyName(&parent()->instruction_name_uniquer());
144     instruction->SetUniqueId(parent()->NewUniqueInstructionId());
145   }
146   instruction->set_parent(this);
147   HloInstruction* pinst = instruction.get();
148   instruction_iterators_[pinst] =
149       instructions_.insert(instructions_.end(), std::move(instruction));
150   return pinst;
151 }
152 
153 HloInstruction* HloComputation::AddParameter(
154     std::unique_ptr<HloInstruction> instruction) {
155   CHECK(instruction->opcode() == HloOpcode::kParameter);
156   CHECK(!IsFusionComputation() ||
157         fusion_instruction_->operand_count() == param_instructions_.size());
158   instruction->set_parent(this);
159   param_instructions_.push_back(instruction.get());
160   AddInstructionInternal(std::move(instruction));
161   return instructions_.back().get();
162 }
163 
164 HloInstruction* HloComputation::AddEntryComputationParameter(
165     std::unique_ptr<HloInstruction> instruction) {
166   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
167   CHECK_EQ(instruction->parameter_number(), num_parameters());
168   CHECK(parent()->entry_computation() == this);
169 
170   HloModuleConfig config = parent()->config();
171   config.mutable_entry_computation_layout()->add_parameter_layout(
172       ShapeLayout(instruction->shape()));
173   parent()->set_config(config);
174 
175   instruction->set_parent(this);
176   param_instructions_.push_back(instruction.get());
177   AddInstructionInternal(std::move(instruction));
178 
179   return instructions_.back().get();
180 }
181 
182 Status HloComputation::ReplaceEntryComputationParameter(
183     int64_t param_no, HloInstruction* old_instruction,
184     std::unique_ptr<HloInstruction> instruction) {
185   CHECK_GE(param_no, 0);
186   CHECK_LT(param_no, param_instructions_.size());
187   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
188   CHECK(parent()->entry_computation() == this);
189 
190   HloModuleConfig config = parent()->config();
191   *config.mutable_entry_computation_layout()->mutable_parameter_layout(
192       param_no) = ShapeLayout(instruction->shape());
193   parent()->set_config(config);
194 
195   instruction->set_parent(this);
196   param_instructions_[param_no] = instruction.get();
197   AddInstructionInternal(std::move(instruction));
198 
199   return ForceRemoveInstruction(old_instruction);
200 }
201 
202 Status HloComputation::RemoveParameter(int64_t param_no) {
203   CHECK_GE(param_no, 0);
204   CHECK_LT(param_no, param_instructions_.size());
205   HloInstruction* param_instruction = param_instructions_[param_no];
206   auto param_instruction_iterator = param_instructions_.begin() + param_no;
207   param_instructions_.erase(param_instruction_iterator);
208   // Throw the removed fused parameter instruction away.
209   TF_RETURN_IF_ERROR(ForceRemoveInstruction(param_instruction));
210 
211   while (param_no < param_instructions_.size()) {
212     param_instruction = param_instructions_[param_no];
213     HloInstruction* new_instr =
214         AddInstructionInternal(HloInstruction::CreateParameter(
215             param_no, param_instruction->shape(), StrCat("param_", param_no)));
216     TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
217     param_instructions_[param_no] = new_instr;
218     TF_RETURN_IF_ERROR(ForceRemoveInstruction(param_instruction));
219     param_no++;
220   }
221 
222   return OkStatus();
223 }
224 
225 HloInstruction* HloComputation::ReplaceParameter(
226     int64_t param_no, std::unique_ptr<HloInstruction> instruction) {
227   CHECK_GE(param_no, 0);
228   CHECK_LT(param_no, param_instructions_.size());
229   CHECK(instruction->opcode() == HloOpcode::kParameter);
230   CHECK(!IsFusionComputation() ||
231         fusion_instruction_->operand_count() == param_instructions_.size());
232 
233   instruction->set_parent(this);
234   HloInstruction* new_instruction =
235       AddInstructionInternal(std::move(instruction));
236   HloInstruction* old_instruction = param_instructions_[param_no];
237   CHECK(
238       old_instruction->ReplaceAllUsesWithDifferentShape(new_instruction).ok());
239   param_instructions_[param_no] = new_instruction;
240   CHECK(RemoveInstruction(old_instruction).ok());
241   return new_instruction;
242 }
243 
244 Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
245   return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
246 }
247 
248 Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
249   return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
250 }
251 
252 Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
253   CHECK(allow_non_fusion || IsFusionComputation());
254   int64_t removed = 0;
255   for (int64_t i = 0; i < param_instructions_.size(); ++i) {
256     HloInstruction* param_instruction = param_instructions_[i];
257     if (param_instruction->IsDead()) {
258       TF_RETURN_IF_ERROR(
259           RemoveInstructionImpl(param_instruction, allow_non_fusion));
260       ++removed;
261       continue;
262     }
263 
264     if (removed > 0) {
265       const int64_t param_no = i - removed;
266       HloInstruction* new_instr = AddInstructionInternal(
267           HloInstruction::CreateParameter(param_no, param_instruction->shape(),
268                                           StrCat("param_", param_no)));
269       TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
270       param_instructions_[param_no] = new_instr;
271       TF_RETURN_IF_ERROR(
272           RemoveInstructionImpl(param_instruction, allow_non_fusion));
273     }
274   }
275   param_instructions_.resize(param_instructions_.size() - removed);
276   return OkStatus();
277 }
278 
279 bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
280   // If the instruction has control predecessors or successors then we cannot
281   // remove the instruction without violating ordering constraints (added, for
282   // example, to avert interference due to buffer aliasing).
283   if (!instruction->control_predecessors().empty() ||
284       !instruction->control_successors().empty()) {
285     return false;
286   }
287 
288   if (instruction->opcode() == HloOpcode::kParameter &&
289       !IsFusionComputation()) {
290     return false;
291   }
292 
293   return true;
294 }
295 
296 bool HloComputation::HasSideEffect() const {
297   for (auto* instruction : instructions()) {
298     if (instruction->HasSideEffect()) {
299       return true;
300     }
301   }
302   return false;
303 }
304 
305 bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {
306   return inst->IsMarkedAsDead();
307 }
308 
309 Status HloComputation::RemoveInstructionAndUnusedOperands(
310     HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
311   TF_RET_CHECK(root_instruction() != instruction);
312 
313   TF_RET_CHECK(instruction->IsDead());
314   TF_RET_CHECK(IsSafelyRemovable(instruction))
315       << "Cannot remove instruction: " << instruction->ToString();
316   absl::flat_hash_set<HloInstruction*> removed;
317   std::queue<HloInstruction*> worklist;
318   worklist.push(instruction);
319   while (!worklist.empty()) {
320     HloInstruction* item = worklist.front();
321     worklist.pop();
322 
323     if (removed.contains(item) || !item->IsDead() || !IsSafelyRemovable(item) ||
324         (item->HasSideEffect() && item != instruction)) {
325       continue;
326     }
327     for (int i = 0; i < item->operand_count(); ++i) {
328       worklist.push(item->mutable_operand(i));
329     }
330 
331     if (cleanup) {
332       cleanup(item);
333     }
334     TF_RETURN_IF_ERROR(RemoveInstruction(item));
335     removed.insert(item);
336   }
337   return OkStatus();
338 }
339 
340 Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
341   return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
342 }
343 
344 Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
345   return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
346 }
347 
348 Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
349                                              bool ignore_safety_check) {
350   VLOG(2) << "Removing instruction " << instruction->name()
351           << " from computation " << name();
352   TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
353       << "cannot remove instruction: " << instruction->ToString();
354   TF_RET_CHECK(instruction->IsDead()) << "instruction " << instruction->name()
355                                       << " is live and cannot be removed";
356   TF_RET_CHECK(instruction->control_predecessors().empty())
357       << "instruction " << instruction->name()
358       << " has control predecessors and cannot be removed";
359   TF_RET_CHECK(instruction->control_successors().empty())
360       << "instruction " << instruction->name()
361       << " has control successors and cannot be removed";
362 
363   auto inst_it = instruction_iterators_.find(instruction);
364   TF_RET_CHECK(inst_it != instruction_iterators_.end());
365   (*inst_it->second)->set_parent(nullptr);
366   to_be_deleted_.emplace_back(inst_it->second->release());
367   to_be_deleted_.back()->DetachFromOperandsAndUsers();
368   // Clear all operands to avoid Null operands.
369   to_be_deleted_.back()->RemoveAllOperands();
370   to_be_deleted_.back()->ClearCalledComputations();
371   to_be_deleted_.back()->MarkAsDead();
372   instructions_.erase(inst_it->second);
373   instruction_iterators_.erase(inst_it);
374   return OkStatus();
375 }
376 
377 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
378                                           bool accept_different_shape) {
379   // The shape of the root (ignoring layout) is an invariant of the computation
380   // for non-fusion cases.
381   if (!IsFusionComputation() && !accept_different_shape) {
382     CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
383                                 root_instruction_->shape()))
384         << new_root_instruction->shape() << " is incompatible with "
385         << root_instruction_->shape();
386   }
387   bool root_found = false;
388   for (auto& instruction : instructions_) {
389     if (new_root_instruction == instruction.get()) {
390       root_found = true;
391       break;
392     }
393   }
394   DCHECK(root_found);
395 
396   if (parent() && parent()->has_entry_computation() &&
397       parent()->entry_computation() == this) {
398     if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
399                                        root_instruction_->shape())) {
400       // Rebuild input output alias config now that we have a new output shape.
401       parent()->input_output_alias_config() =
402           HloInputOutputAliasConfig(new_root_instruction->shape());
403     }
404   }
405 
406   root_instruction_ = new_root_instruction;
407 }
408 
409 namespace {
410 
411 // Helper which builds a post order of the HLO call graph.
412 void ComputeComputationPostOrder(HloComputation* computation,
413                                  absl::flat_hash_set<HloComputation*>* visited,
414                                  std::vector<HloComputation*>* post_order) {
415   if (visited->insert(computation).second) {
416     for (auto* instruction : computation->instructions()) {
417       for (HloComputation* called_computation :
418            instruction->called_computations()) {
419         ComputeComputationPostOrder(called_computation, visited, post_order);
420       }
421     }
422     post_order->push_back(computation);
423   }
424 }
425 
426 std::optional<int64_t> GetChannelId(const HloInstruction& inst) {
427   // Note that we only include Send and RecvDone, as we want to create a
428   // dependency between those, but not SendDone and Recv.
429   switch (inst.opcode()) {
430     case HloOpcode::kSend:
431     case HloOpcode::kRecvDone:
432     case HloOpcode::kAllReduce:
433     case HloOpcode::kAllGather:
434     case HloOpcode::kAllToAll:
435     case HloOpcode::kCollectivePermute:
436     case HloOpcode::kReduceScatter:
437       return inst.channel_id();
438     default:
439       return std::nullopt;
440   }
441 }
442 
443 }  // namespace
444 
445 void HloComputation::ComputeInstructionPostOrder(
446     HloInstruction* root,
447     HloComputation::ChannelDependencyGroup& channel_dependencies,
448     absl::flat_hash_map<HloInstruction*, VisitState>& visited,
449     std::vector<HloInstruction*>& post_order) const {
450   std::vector<HloInstruction*> dfs_stack = {root};
451   while (!dfs_stack.empty()) {
452     HloInstruction& current = *dfs_stack.back();
453 
454     auto result = visited.insert({&current, kVisiting});
455     if (!result.second) {  // We've already seen this instruction.
456       dfs_stack.pop_back();
457       if (result.first->second != kVisited) {
458         CHECK_EQ(current.parent(), this)
459             << "Instruction " << current.name()
460             << " is not in the current computation (" << name() << ").";
461         post_order.push_back(&current);
462         result.first->second = kVisited;
463       }
464       continue;
465     }
466 
467     // Add channel dependencies.
468     // A RecvDone op must be preceded by the corresponding Send op.
469     // Collectives with the same channel ID must be performed together, as they
470     // represent MPMD-partitioned graphs that will later be split into separate
471     // modules, and their order must be preserved.
472     std::optional<int64_t> channel_id =
473         ((&current != root) && (current.opcode() != HloOpcode::kSend))
474             ? GetChannelId(current)
475             : std::nullopt;
476     if (channel_id) {
477       auto it = channel_dependencies.find(*channel_id);
478       if (it != channel_dependencies.end()) {
479         dfs_stack.insert(dfs_stack.end(), it->second.begin(), it->second.end());
480         channel_dependencies.erase(it);
481       }
482     }
483 
484     // Add the operands to the stack in reverse order so the first operand is
485     // processed first. This will produce a more natural ordering and a nicer
486     // result for things like HLO stringification.
487     const HloInstruction::InstructionVector& operands = current.operands();
488     dfs_stack.insert(dfs_stack.end(), operands.rbegin(), operands.rend());
489 
490     const std::vector<HloInstruction*>& predecessors =
491         current.control_predecessors();
492     dfs_stack.insert(dfs_stack.end(), predecessors.begin(), predecessors.end());
493   }
494 }
495 
496 HloComputation::ChannelDependencyGroup
497 HloComputation::ComputeChannelDependencies() const {
498   if (parent() && parent()->config().has_static_device_assignment() &&
499       (parent()->config().static_device_assignment().computation_count() == 1 ||
500        parent()->config().use_spmd_partitioning())) {
501     return {};
502   }
503 
504   ChannelDependencyGroup channel_dependencies;
505   for (const auto& instruction : instructions_) {
506     std::optional<int64_t> channel_id = GetChannelId(*instruction);
507     if (channel_id)
508       channel_dependencies[*channel_id].push_back(instruction.get());
509   }
510   return channel_dependencies;
511 }
512 
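// Note on the helper below: its predicate always returns false, and
// absl::c_all_of over an empty range is true, so it effectively returns true
// only for instructions that have no users.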
513 static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
514   return absl::c_all_of(instruction->users(),
515                         [](HloInstruction* user) { return false; });
516 }
517 
518 std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
519   ChannelDependencyGroup channel_dependencies = ComputeChannelDependencies();
520   std::vector<HloInstruction*> post_order;
521   post_order.reserve(instruction_count());
522   std::vector<HloInstruction*> trace_instructions;
523   absl::flat_hash_map<HloInstruction*, VisitState> visited;
524   visited.reserve(instruction_count());
525   for (auto& instruction : instructions_) {
526     if (HasOnlyTraceUsers(instruction.get())) {
527       ComputeInstructionPostOrder(instruction.get(), channel_dependencies,
528                                   visited, post_order);
529     }
530   }
531   post_order.insert(post_order.end(), trace_instructions.begin(),
532                     trace_instructions.end());
533   CHECK_EQ(instructions_.size(), post_order.size())
534       << "number of instructions does not match post order size";
535   return post_order;
536 }
537 
538 std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
539     const {
540   absl::flat_hash_set<HloComputation*> visited;
541   std::vector<HloComputation*> post_order;
542 
543   // To avoid special handling of this computation, cast away const of
544   // 'this'. 'this' is immediately removed from the post order after
545   // construction.
546   //
547   // TODO(b/78350259): This violates const-correctness, since while the original
548   // computation is not returned, we still retrieve non-const computations from
549   // a const one. Consider also avoiding const for HloComputation, or review XLA
550   // for const-correctness of non-HloInstruction* types like this.
551   ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
552                               &post_order);
553 
554   // We don't want to include this computation in the post order.
555   CHECK_EQ(this, post_order.back());
556   post_order.pop_back();
557 
558   return post_order;
559 }
560 
561 std::string HloComputation::ToString(const HloPrintOptions& options) const {
562   return std::string(ToCord(options));
563 }
564 
565 std::string HloComputation::ToString(
566     const HloPrintOptions& options,
567     absl::Span<const HloInstruction* const> instruction_order) const {
568   return std::string(ToCord(options, instruction_order));
569 }
570 
571 absl::Cord HloComputation::ToCord(const HloPrintOptions& options) const {
572   return ToCord(options, MakeInstructionPostOrder());
573 }
574 
575 absl::Cord HloComputation::ToCord(
576     const HloPrintOptions& options,
577     absl::Span<const HloInstruction* const> instruction_order) const {
578   CHECK_EQ(instruction_order.size(), instruction_count());
579   const std::string tab(2 * options.indent_amount(), ' ');
580 
581   absl::Cord result;
582   result.Append(tab);
583 
584   if (!options.is_in_nested_computation()) {
585     if (options.print_percent()) {
586       result.Append("%");
587     }
588     if (options.print_ids()) {
589       // When print_ids() is false, exclude the entry computation's name because
590       // it includes ids and leads to a non-deterministic fingerprint.
591       result.Append(name());
592       result.Append(" ");
593     }
594   }
595 
596   if (options.print_program_shape()) {
597     result.Append(
598         ShapeUtil::HumanString(ComputeProgramShape(options.print_ids())));
599     result.Append(" ");
600   }
601   result.Append("{\n");
602 
603   {
604     // Print the instructions in this computation.
605     HloPrintOptions new_options =
606         HloPrintOptions(options)
607             .set_indent_amount(options.indent_amount() + 1)
608             .set_is_in_nested_computation(true);
609 
610     const std::string new_tab(2 * new_options.indent_amount(), ' ');
611 
612     CanonicalNameMap name_map;
613     for (const HloInstruction* const instruction : instruction_order) {
614       DCHECK_EQ(this, instruction->parent());
615       result.Append(new_tab);
616       if (instruction == root_instruction_) {
617         result.Append("ROOT ");
618       }
619       result.Append(
620           instruction->ToStringWithCanonicalNameMap(new_options, &name_map));
621       result.Append("\n");
622     }
623   }
624 
625   result.Append(tab);
626   result.Append("}");
627   if (options.print_ids() && !IsMainThread()) {
628     // When print_ids() is false, exclude the computation's thread name because
629     // it would lead to a non-deterministic fingerprint.
630     result.Append(StrCat(", execution_thread=\"", execution_thread(), "\""));
631   }
632   return result;
633 }
634 
635 HloComputationProto HloComputation::ToProto() const {
636   HloComputationProto proto;
637   CHECK(unique_id_ != -1)
638       << "This computation does not have a valid id. Please make sure the "
639          "computation is inside a module before dumping it.";
640   proto.set_id(unique_id_);
641   proto.set_name(name_);
642   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
643     HloInstructionProto instruction_proto = instruction->ToProto();
644     proto.add_instructions()->Swap(&instruction_proto);
645   }
646   proto.set_root_id(root_instruction()->unique_id());
647   *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
648   proto.set_is_fusion_computation(is_fusion_computation_);
649   proto.set_execution_thread(IsMainThread() ? ""
650                                             : std::string(execution_thread()));
651   return proto;
652 }
653 
654 /* static */ StatusOr<std::unique_ptr<HloComputation>>
655 HloComputation::CreateFromProto(
656     const HloComputationProto& proto,
657     const absl::flat_hash_map<int64_t, HloComputation*>& computation_map,
658     bool prohibit_empty_literal) {
659   absl::flat_hash_map<int64_t, HloInstruction*> instruction_map;
660   absl::flat_hash_map<HloInstruction*, int64_t> to_proto_id;
661   std::vector<std::unique_ptr<HloInstruction>> instructions;
662   int64_t parameter_count = 0;
663   for (const HloInstructionProto& instruction_proto : proto.instructions()) {
664     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
665                         HloInstruction::CreateFromProto(
666                             instruction_proto, instruction_map, computation_map,
667                             prohibit_empty_literal));
668     if (instruction->opcode() == HloOpcode::kParameter) {
669       parameter_count++;
670     }
671     TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
672     instruction_map[instruction_proto.id()] = instruction.get();
673     to_proto_id[instruction.get()] = instruction_proto.id();
674     instructions.push_back(std::move(instruction));
675   }
676 
677   TF_RET_CHECK(proto.root_id() != -1);
678   TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
679   HloInstruction* root = instruction_map.at(proto.root_id());
680 
681   // Sort the instructions in the proto id's order.
682   absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
683                                  const std::unique_ptr<HloInstruction>& b) {
684     return to_proto_id[a.get()] < to_proto_id[b.get()];
685   });
686 
687   TF_RETURN_IF_ERROR([&]() -> Status {
688     std::vector<bool> parameters_seen(parameter_count);
689     int parameters_seen_count = 0;
690     for (auto& instruction : instructions) {
691       if (instruction->opcode() == HloOpcode::kParameter) {
692         int64_t param_no = instruction->parameter_number();
693         TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
694             << "Invalid parameter number.  Expected [0, " << parameter_count
695             << "), got " << param_no;
696         TF_RET_CHECK(!parameters_seen[param_no])
697             << "Parameter number " << param_no
698             << " already allocated in this computation";
699         parameters_seen[param_no] = true;
700         parameters_seen_count++;
701       }
702     }
703     TF_RET_CHECK(parameters_seen_count == parameter_count)
704         << "Not all parameters in range [0, " << parameter_count
705         << ") were referenced";
706     return OkStatus();
707   }());
708 
709   auto computation = absl::WrapUnique(
710       new HloComputation(proto.name(), parameter_count, &instructions, root,
711                          /*fusion_instruction=*/nullptr));
712   computation->unique_id_ = proto.id();
713   computation->is_fusion_computation_ = proto.is_fusion_computation();
714   if (!proto.execution_thread().empty()) {
715     computation->SetExecutionThread(proto.execution_thread());
716   }
717   return std::move(computation);
718 }
719 
720 void HloComputation::AppendInstructionsIntoCalledComputation(
721     absl::Span<HloInstruction* const> instructions_to_append,
722     HloInstruction* caller) {
723   HloInstruction* root = instructions_to_append.front();
724   TF_CHECK_OK(root->ReplaceAllUsesWith(caller));
725   if (root == root_instruction()) {
726     set_root_instruction(caller);
727   }
728   TF_CHECK_OK(RemoveInstruction(root));
729   for (size_t i = 1; i < instructions_to_append.size(); ++i) {
730     HloInstruction* instruction = instructions_to_append[i];
731     caller->AppendInstructionIntoCalledComputation(instruction);
732     if (instruction->IsDead()) {
733       TF_CHECK_OK(RemoveInstruction(instruction));
734     }
735   }
736 }
737 
738 HloInstruction* HloComputation::CreateFusionInstruction(
739     absl::Span<HloInstruction* const> instructions_to_fuse,
740     HloInstruction::FusionKind fusion_kind) {
741   HloInstruction* root = instructions_to_fuse.front();
742   HloInstruction* fusion_instruction = AddInstruction(
743       HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
744   AppendInstructionsIntoCalledComputation(instructions_to_fuse,
745                                           fusion_instruction);
746   return fusion_instruction;
747 }
748 
749 HloInstruction* HloComputation::CreateCallInstruction(
750     absl::Span<HloInstruction* const> instructions_to_call) {
751   HloInstruction* root = instructions_to_call.front();
752   HloInstruction* call_instruction =
753       AddInstruction(HloInstruction::CreateCall(root->shape(), root));
754   AppendInstructionsIntoCalledComputation(instructions_to_call,
755                                           call_instruction);
756   return call_instruction;
757 }
758 
759 StatusOr<HloInstruction*> HloComputation::CreateAsyncInstructions(
760     HloInstruction* instruction, absl::Span<const Shape> context_shapes,
761     absl::string_view async_execution_thread) {
762   Builder builder("async_computation");
763   std::vector<HloInstruction*> parameters(instruction->operand_count());
764   std::vector<Shape> parameter_shapes(instruction->operand_count());
765   for (int i = 0; i < instruction->operand_count(); ++i) {
766     const Shape& parameter_shape = instruction->operand(i)->shape();
767     parameters[i] = builder.AddInstruction(HloInstruction::CreateParameter(
768         i, parameter_shape, absl::StrCat("param_", i)));
769     parameter_shapes[i] = parameter_shape;
770   }
771   HloInstruction* root = builder.AddInstruction(
772       instruction->CloneWithNewOperands(instruction->shape(), parameters));
773   HloComputation* async_computation =
774       parent_->AddEmbeddedComputation(builder.Build(root));
775   std::vector<Shape> start_shapes = {
776       ShapeUtil::MakeTupleShape(parameter_shapes), root->shape()};
777   for (const Shape& context_shape : context_shapes) {
778     start_shapes.push_back(context_shape);
779   }
780   HloInstruction* async_start = AddInstruction(HloInstruction::CreateAsyncStart(
781       ShapeUtil::MakeTupleShape(start_shapes), instruction->operands(),
782       async_computation, /*async_group_id=*/std::nullopt,
783       async_execution_thread));
784   HloInstruction* async_done = AddInstruction(HloInstruction::CreateAsyncDone(
785       root->shape(), async_start, async_computation,
786       /*async_group_id=*/std::nullopt, async_execution_thread));
787   async_start->set_metadata(instruction->metadata());
788   async_start->CopyBackendConfigFrom(instruction);
789   async_done->set_metadata(instruction->metadata());
790   async_done->CopyBackendConfigFrom(instruction);
791   TF_RETURN_IF_ERROR(ReplaceInstruction(instruction, async_done));
792   return async_done;
793 }
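// For reference, the async-start built above has the tuple shape
//   ((operand shapes...), wrapped root shape, context shapes...),
// the async-done yields the wrapped root shape, and both inherit the original
// instruction's metadata and backend config.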
794 
795 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
796     HloInstruction* instruction, ShapeIndex* index,
797     const std::function<
798         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
799                         HloComputation* computation)>& copy_leaf) {
800   if (instruction->shape().IsTuple()) {
801     std::vector<HloInstruction*> elements;
802     for (int64_t i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
803          i++) {
804       HloInstruction* gte =
805           AddInstruction(HloInstruction::CreateGetTupleElement(
806               ShapeUtil::GetTupleElementShape(instruction->shape(), i),
807               instruction, i));
808 
809       index->push_back(i);
810       TF_ASSIGN_OR_RETURN(HloInstruction * element,
811                           DeepCopyHelper(gte, index, copy_leaf));
812       elements.push_back(element);
813       index->pop_back();
814     }
815     return AddInstruction(HloInstruction::CreateTuple(elements));
816   }
817   if (instruction->shape().IsToken()) {
818     // Tokens have no on-device representation and cannot be copied. Pass
819     // through transparently.
820     return instruction;
821   }
822 
823   // Array shape.
824   TF_RET_CHECK(instruction->shape().IsArray());
825   return copy_leaf(instruction, *index, this);
826 }
827 
828 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
829     HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
830     ShapeTree<HloInstruction*>* copies_added) {
831   if (instruction->parent() != this) {
832     return FailedPrecondition(
833         "Can't deep copy instruction %s: instruction is not in computation %s",
834         instruction->name(), name());
835   }
836   if (indices_to_copy != nullptr &&
837       !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
838     return FailedPrecondition(
839         "Can't deep copy instruction %s: given shape tree of indices to copy "
840         "has incompatible shapes: %s vs. %s",
841         instruction->name(), ShapeUtil::HumanString(instruction->shape()),
842         ShapeUtil::HumanString(indices_to_copy->shape()));
843   }
844 
845   ShapeIndex index;
846   auto copy_leaf = [indices_to_copy, copies_added](
847                        HloInstruction* leaf, const ShapeIndex& leaf_index,
848                        HloComputation* computation) {
849     if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
850       HloInstruction* copy = computation->AddInstruction(
851           HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
852       if (copies_added != nullptr) {
853         *copies_added->mutable_element(leaf_index) = copy;
854       }
855       return copy;
856     }
857     // Elements which are not to be copied are passed through
858     // transparently.
859     return leaf;
860   };
861   return DeepCopyHelper(instruction, &index, copy_leaf);
862 }
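// Illustrative behavior: DeepCopyInstruction(inst) with a null indices_to_copy
// inserts a kCopy for every array leaf of inst, passes token leaves through
// unchanged, and rebuilds any tuple structure with get-tuple-element/tuple
// instructions.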
863 
864 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
865     HloInstruction* instruction,
866     const std::function<
867         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
868                         HloComputation* computation)>& copy_leaf) {
869   if (instruction->parent() != this) {
870     return FailedPrecondition(
871         "Can't deep copy instruction %s: instruction is not in computation %s",
872         instruction->name(), name());
873   }
874   ShapeIndex index;
875   return DeepCopyHelper(instruction, &index, copy_leaf);
876 }
877 
878 ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
879   ProgramShape program_shape;
880 
881   for (auto* param_instruction : param_instructions_) {
882     *program_shape.add_parameters() = param_instruction->shape();
883     *program_shape.add_parameter_names() =
884         PrintName(param_instruction->name(), include_ids);
885   }
886   *program_shape.mutable_result() = root_instruction_->shape();
887 
888   return program_shape;
889 }
890 
891 bool HloComputation::EqualInternal(const HloComputation& other,
892                                    bool is_layout_sensitive,
893                                    bool ignore_channel_id_values,
894                                    bool ignore_thread) const {
895   if (this == &other) {
896     return true;
897   }
898   absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
899       visited;
900   std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
901 
902   worklist.push_back({root_instruction(), other.root_instruction()});
903 
904   while (!worklist.empty()) {
905     auto pair = worklist.back();
906     worklist.pop_back();
907 
908     if (visited.contains(pair)) {
909       continue;
910     }
911     visited.emplace(pair);
912     // TODO(b/123082518): Avoid recursively invoking Equal because it may
913     // cause a stack overflow with deeply nested subcomputations.
914     auto operands_eq = [](const HloInstruction*, const HloInstruction*) {
915       return true;
916     };
917     auto comp_eq = [&](const HloComputation* a, const HloComputation* b) {
918       return a->EqualInternal(*b, is_layout_sensitive, ignore_channel_id_values,
919                               ignore_thread);
920     };
921     bool identical_ignoring_operands =
922         ignore_channel_id_values
923             ? pair.first->IdenticalIgnoringChannelIdValues(
924                   *pair.second, operands_eq, comp_eq, is_layout_sensitive)
925             : pair.first->Identical(*pair.second, operands_eq, comp_eq,
926                                     is_layout_sensitive);
927     if (!identical_ignoring_operands) {
928       return false;
929     }
930     for (size_t i = 0; i < pair.first->operands().size(); ++i) {
931       worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
932     }
933   }
934 
935   if (!ignore_thread) {
936     return execution_thread() == other.execution_thread();
937   }
938   return true;
939 }
940 
941 Status HloComputation::ReplaceWithNewInstruction(
942     HloInstruction* old_instruction,
943     std::unique_ptr<HloInstruction> new_instruction) {
944   return ReplaceInstruction(old_instruction,
945                             AddInstruction(std::move(new_instruction)));
946 }
947 
948 Status HloComputation::ReplaceWithNewEntryComputationParameter(
949     HloInstruction* old_instruction,
950     std::unique_ptr<HloInstruction> new_instruction) {
951   return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
952                                                  std::move(new_instruction)));
953 }
954 
955 StatusOr<bool> HloComputation::ReplaceInstruction(
956     HloInstruction* old_instruction, HloInstruction* new_instruction,
957     bool preserve_sharding) {
958   TF_RET_CHECK(
959       ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
960       << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
961       << ShapeUtil::HumanString(new_instruction->shape());
962   return ReplaceInstructionWithDifferentShape(old_instruction, new_instruction,
963                                               preserve_sharding);
964 }
965 
966 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
967                                           HloInstruction* new_instruction) {
968   TF_ASSIGN_OR_RETURN(bool changed,
969                       ReplaceInstruction(old_instruction, new_instruction,
970                                          /*preserve_sharding=*/false));
971   DCHECK(changed);
972   return OkStatus();
973 }
974 
975 StatusOr<bool> HloComputation::ReplaceInstructionWithDifferentShape(
976     HloInstruction* old_instruction, HloInstruction* new_instruction,
977     bool preserve_sharding) {
978   if (preserve_sharding && new_instruction->has_sharding() &&
979       old_instruction->has_sharding() &&
980       !new_instruction->has_compatible_sharding(old_instruction)) {
981     VLOG(10) << "Skipping replacement due to incompatible sharding";
982     return false;
983   }
984   VLOG(10) << "transformed " << old_instruction->ToString() << " to "
985            << new_instruction->ToString();
986   // Try to add metadata for HLO instructions that are created to replace
987   // existing HLO instructions (e.g. during optimizations). The assumption is
988   // that the old instruction and the new instruction would perform the same
989   // function, and that they would be correlated to the same TF op. This might
990   // not always be correct since HLO optimizations can cross TF op boundaries.
991   // But still this seems to be better than nothing.
992   bool overwrite_op_name = new_instruction->metadata().op_name().empty() &&
993                            !old_instruction->metadata().op_name().empty();
994   bool overwrite_pass_id =
995       new_instruction->metadata().op_name().empty() &&
996       new_instruction->metadata().logical_creation_pass_id() == 0 &&
997       old_instruction->metadata().logical_creation_pass_id() != 0;
998   if (overwrite_op_name || overwrite_pass_id) {
999     new_instruction->set_metadata(old_instruction->metadata());
1000   }
1001   if (new_instruction->frontend_attributes().map().empty()) {
1002     new_instruction->set_frontend_attributes(
1003         old_instruction->frontend_attributes());
1004   }
1005 
1006   // Like the metadata above, if the user didn't specify any sharding
1007   // information on the new instruction we should copy the old sharding
1008   // information (if any).
1009   if (!new_instruction->has_sharding()) {
1010     new_instruction->set_sharding(old_instruction->sharding_ptr());
1011   }
1012 
1013   TF_RETURN_IF_ERROR(
1014       old_instruction->ReplaceAllUsesWithDifferentShape(new_instruction));
1015 
1016   // Preserve the old instruction's name if the new and old instructions have the
1017   // same opcode.  This makes it easier to follow instructions as they're
1018   // mutated through passes.
1019   if (old_instruction->opcode() == new_instruction->opcode() &&
1020       (old_instruction->opcode() != HloOpcode::kCustomCall ||
1021        old_instruction->custom_call_target() ==
1022            new_instruction->custom_call_target())) {
1023     new_instruction->SetAndSanitizeName(old_instruction->name());
1024   }
1025 
1026   TF_RETURN_IF_ERROR(RemoveInstructionAndUnusedOperands(old_instruction));
1027   return true;
1028 }
1029 
1030 Status HloComputation::ReplaceInstructionWithDifferentShape(
1031     HloInstruction* old_instruction, HloInstruction* new_instruction) {
1032   TF_ASSIGN_OR_RETURN(bool changed, ReplaceInstructionWithDifferentShape(
1033                                         old_instruction, new_instruction,
1034                                         /*preserve_sharding=*/false));
1035   DCHECK(changed);
1036   return OkStatus();
1037 }
1038 
1039 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
1040   std::vector<HloInstruction*> unreachable_roots;
1041   for (auto* instruction : instructions()) {
1042     if (instruction->IsDead() && instruction->control_successors().empty()) {
1043       unreachable_roots.push_back(instruction);
1044     }
1045   }
1046   VLOG(3) << "Unreachable roots:"
1047           << absl::StrJoin(unreachable_roots, "\n\t",
1048                            [](std::string* out, const HloInstruction* hlo) {
1049                              absl::StrAppend(out, hlo->ToString());
1050                            });
1051   return unreachable_roots;
1052 }
1053 
1054 Status HloComputation::AcceptWithOperandOrder(
1055     DfsHloVisitor* visitor,
1056     const HloInstruction::CompareFunction& operand_order) const {
1057   // Visit unreachable roots. Beware that the visitor might delete the currently
1058   // visited root, which would invalidate iterators if the unreachable roots
1059   // weren't computed ahead of time.
1060   for (HloInstruction* root : CollectUnreachableRoots()) {
1061     TF_RETURN_IF_ERROR(
1062         root->AcceptWithOperandOrder(visitor, operand_order,
1063                                      /*call_finish_visit=*/false));
1064   }
1065   // Visit the computation root instruction last.
1066   return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
1067                                                     /*call_finish_visit=*/true);
1068 }
1069 
1070 std::unique_ptr<HloComputation> HloComputation::Clone(
1071     const std::string& suffix, HloCloneContext* context) {
1072   return CloneWithReplacements(
1073       /*replacements=*/nullptr,
1074       /*extra_parameters=*/{}, context, suffix);
1075 }
1076 
1077 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1078     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1079     HloCloneContext* context, const std::string& suffix) {
1080   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1081       replacements;
1082   replacements.emplace(std::move(r1));
1083   return CloneWithReplacements(&replacements, /*extra_parameters=*/{}, context,
1084                                suffix);
1085 }
1086 
1087 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1088     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1089     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1090     HloCloneContext* context, const std::string& suffix) {
1091   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1092       replacements;
1093   replacements.emplace(std::move(r1));
1094   replacements.emplace(std::move(r2));
1095   return CloneWithReplacements(&replacements, /*extra_parameters=*/{}, context,
1096                                suffix);
1097 }
1098 
1099 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1100     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1101     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1102     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
1103     HloCloneContext* context, const std::string& suffix) {
1104   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1105       replacements;
1106   replacements.emplace(std::move(r1));
1107   replacements.emplace(std::move(r2));
1108   replacements.emplace(std::move(r3));
1109   return CloneWithReplacements(&replacements, /*extra_parameters=*/{}, context,
1110                                suffix);
1111 }
1112 
1113 namespace {
1114 
1115 // Sorts unordered_instructions according to the order of ordered_instructions,
1116 // using MappedPtrContainerSorter. context and replace are used to map
1117 // instructions in ordered_instructions to instructions in
1118 // unordered_instructions. Unmapped parameter instructions are placed just after
1119 // the last parameter instruction in the sorted mapped instruction order. All
1120 // other unmapped instructions are placed at the end.
1121 void SortClonedInstructions(
1122     const HloCloneContext& context,
1123     const std::function<const HloInstruction*(const HloInstruction*)>& replace,
1124     const HloComputation& computation,
1125     const HloComputation::InstructionList& ordered_instructions,
1126     std::vector<std::unique_ptr<HloInstruction>>& unordered_instructions) {
1127   using InstructionSorter = MappedPtrContainerSorter<HloInstruction>;
1128   InstructionSorter::MapPtrFn instruction_mapper =
1129       [&context, &replace](const HloInstruction* i) {
1130         return context.FindInstruction(replace(i));
1131       };
1132   size_t num_mapped_instructions = 0;
1133   size_t mapped_index_of_last_parameter_plus_one = 0;
1134   for (const auto& instruction : ordered_instructions) {
1135     if (!instruction_mapper(instruction.get())) {
1136       continue;
1137     }
1138     ++num_mapped_instructions;
1139     if (!dynamic_cast<const HloParameterInstruction*>(instruction.get())) {
1140       continue;
1141     }
1142     mapped_index_of_last_parameter_plus_one = num_mapped_instructions;
1143   }
1144   InstructionSorter::UnmappedPtrIndexFn unmapped_ptr_index =
1145       [num_mapped_instructions,
1146        mapped_index_of_last_parameter_plus_one](const HloInstruction* i) {
1147         if (dynamic_cast<const HloParameterInstruction*>(i)) {
1148           if (num_mapped_instructions > 0 &&
1149               mapped_index_of_last_parameter_plus_one > 0) {
1150             return mapped_index_of_last_parameter_plus_one - 1;
1151           }
1152           return InstructionSorter::IndexBeforeMappedElementsFn()(i);
1153         }
1154         return InstructionSorter::IndexAfterMappedElementsFn()(i);
1155       };
1156   auto status =
1157       InstructionSorter::Sort(instruction_mapper, unmapped_ptr_index,
1158                               ordered_instructions, unordered_instructions);
1159   if (!status.ok()) {
1160     LOG(ERROR) << "Failed to reorder instructions while cloning computation: "
1161                << computation.name() << "; " << status;
1162   }
1163 }

// For each cloned instruction, sorts its users, control predecessors, and
// control successors to match the order of those lists in the corresponding
// original instruction before cloning. context and replace are used to map
// original instructions to their clones.
void SortClonedInstructionUsersAndControlLists(
    const HloCloneContext& context,
    const std::function<const HloInstruction*(const HloInstruction*)>& replace,
    const HloComputation::InstructionList& sorted_instructions) {
  using InstructionSorter = MappedPtrContainerSorter<HloInstruction>;
  InstructionSorter::MapPtrFn instruction_mapper =
      [&context, &replace](const HloInstruction* i) {
        return context.FindInstruction(replace(i));
      };
  for (const std::unique_ptr<HloInstruction>& instruction :
       sorted_instructions) {
    HloInstruction* cloned_instruction =
        context.FindInstruction(replace(instruction.get()));
    if (!cloned_instruction) {
      continue;
    }
    cloned_instruction->SortInstructionUsersAndControlLists(instruction_mapper,
                                                            *instruction);
  }
}
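
// Illustrative example (hypothetical names): if the original parameter p0 has
// users [add, mul] in that order, this pass reorders the user list of p0's
// clone to [add', mul'] as well; control predecessor and successor lists are
// reordered the same way, so clones stay list-for-list consistent with their
// originals.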

}  // namespace

std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    const absl::flat_hash_map<const HloInstruction*,
                              std::unique_ptr<HloInstruction>>* replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const std::string& suffix,
    const HloInstruction* new_root) {
  std::unique_ptr<HloCloneContext> context_ptr;
  if (context == nullptr) {
    context_ptr = std::make_unique<HloCloneContext>(parent(), suffix);
    context = context_ptr.get();
  }
  if (new_root == nullptr) {
    new_root = root_instruction();
  }

  // Look up instr in the replacements map, and return either the replacement
  // or instr itself if no replacement is present.
  //
  // Note: this can return null, indicating that instr should not be present
  // in the new computation.
  auto replace = [&](const HloInstruction* instr) {
    if (!replacements) return instr;
    auto it = replacements->find(instr);
    return it != replacements->end() ? it->second.get() : instr;
  };
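
  // For example (hypothetical map contents): with replacements = {a -> a2,
  // b -> nullptr}, replace(a) == a2, replace(b) == nullptr (b is dropped from
  // the clone), and replace(c) == c for any instruction c with no map entry.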

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We want to do a postorder walk over [replace(i) for i in instructions_].
  // We can't reuse MakeInstructionPostOrder() for this, because that would
  // generate a postorder of the plain instructions_, and our replacements may
  // change the postorder!
  //
  // The postorder we want here is simpler than what MakeInstructionPostOrder()
  // does -- we only care about operand dependencies -- so we just do the walk
  // ourselves.
  std::vector<const HloInstruction*> postorder;
  absl::flat_hash_map<const HloInstruction*, VisitState> visited;
  for (const auto& instr : instructions_) {
    std::vector<const HloInstruction*> dfs_stack;
    const HloInstruction* new_instr = replace(instr.get());
    if (!new_instr) {
      continue;
    }
    dfs_stack.push_back(new_instr);

    while (!dfs_stack.empty()) {
      auto* cur = dfs_stack.back();
      auto it = visited.find(cur);
      if (it != visited.end()) {
        dfs_stack.pop_back();
        if (it->second == kVisited) {
          continue;
        }
        // Seeing a kVisiting node again means all of its operands have been
        // emitted, so the node itself can now be appended to the postorder.
        CHECK_EQ(it->second, kVisiting);
        postorder.push_back(cur);
        it->second = kVisited;
        continue;
      }

      visited.insert({cur, kVisiting});
      for (HloInstruction* operand : cur->operands()) {
        const HloInstruction* new_operand = replace(operand);
        if (new_operand) {
          dfs_stack.emplace_back(new_operand);
        }
      }
    }
  }
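
  // At this point postorder holds every instruction that survives the
  // replacements, each one appearing after all of its (replaced) operands.
  // For example (hypothetical): with root add = Add(p0, p1) and an empty
  // replacements map, postorder contains p0 and p1 (in visitation order)
  // followed by add.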

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // First add the extra parameters to 'instructions'.
  for (const auto& instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  // Then clone every instruction in postorder, wiring each clone to the
  // already-cloned operands recorded in the clone context.
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(context->GetInstruction(replaced_operand));
    }
    std::unique_ptr<HloInstruction> new_instr =
        instr->CloneWithNewOperands(instr->shape(), new_operands, context);
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parameter_replicated_at_leaf_buffers().has_value()) {
      new_instr->set_parameter_replicated_at_leaf_buffers(
          instr->parameter_replicated_at_leaf_buffers().value());
    }
    instructions.push_back(std::move(new_instr));
  }

  // To make clone behavior match uncloned behavior, we reorder instructions to
  // match the order in instructions_.
  SortClonedInstructions(*context, replace, *this, instructions_, instructions);

  Builder builder(suffix.empty() ? name() : name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(replace(new_root)));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = context->GetInstruction(instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);
      // successor may not have been remapped, because it might have been
      // removed by the replacements map.
      if (replaced_successor != nullptr) {
        TF_CHECK_OK(new_instr->AddControlDependencyTo(
            context->GetInstruction(replaced_successor)));
      }
    }
  }

  // To make clone behavior match uncloned behavior, we reorder the user and
  // control lists kept by the cloned instructions.
  SortClonedInstructionUsersAndControlLists(*context, replace, instructions_);

  context->MapComputation(this, result.get());
  result->SetExecutionThread(execution_thread());

  return result;
}
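
// Illustrative usage sketch (hypothetical caller code, not part of this
// file): clone a computation while swapping one instruction for a freshly
// built replacement and dropping another. old_instr, dead_instr, computation,
// and MakeReplacement() are assumed names; dead_instr must not be used by any
// surviving instruction, or the CHECK in the cloning loop above fires.
//
//   absl::flat_hash_map<const HloInstruction*,
//                       std::unique_ptr<HloInstruction>>
//       replacements;
//   replacements.emplace(old_instr, MakeReplacement(old_instr));
//   replacements.emplace(dead_instr, nullptr);  // Drop from the clone.
//   std::unique_ptr<HloComputation> clone =
//       computation->CloneWithReplacements(&replacements,
//                                          /*extra_parameters=*/{},
//                                          /*context=*/nullptr,
//                                          /*suffix=*/"clone",
//                                          /*new_root=*/nullptr);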

void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  name_ = name_uniquer->GetUniqueName(name_);
}

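// Note: this does a linear scan over all instructions in the computation, so
// each lookup costs O(#instructions). Callers performing many lookups may
// prefer to build a name-to-instruction map once instead.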
HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
  auto instructions_in_computation = instructions();
  auto it = absl::c_find_if(
      instructions_in_computation,
      [&](HloInstruction* instr) { return instr->name() == name; });
  return it == instructions_in_computation.end() ? nullptr : *it;
}

bool HloComputation::IsEntryComputation() const {
  return parent()->entry_computation() == this;
}
}  // namespace xla