• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 
47 namespace xla {
48 
49 using absl::StrCat;
50 
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52     HloInstruction* root_instruction) {
53   int parameter_count = 0;
54   for (auto& instruction : instructions_) {
55     if (instruction->opcode() == HloOpcode::kParameter) {
56       parameter_count++;
57     }
58   }
59   // If root_instruction is not specified use the last added instruction.
60   HloInstruction* root =
61       root_instruction ? root_instruction : last_added_instruction_;
62   CHECK_NE(nullptr, root);
63   return absl::WrapUnique(new HloComputation(
64       name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
66 
HloComputation(const string & name,int parameter_count,std::vector<std::unique_ptr<HloInstruction>> * instructions,HloInstruction * root_instruction,HloInstruction * fusion_instruction)67 HloComputation::HloComputation(
68     const string& name, int parameter_count,
69     std::vector<std::unique_ptr<HloInstruction>>* instructions,
70     HloInstruction* root_instruction, HloInstruction* fusion_instruction)
71     : name_(NameUniquer::GetSanitizedName(name)),
72       unique_id_(-1),
73       root_instruction_(root_instruction),
74       fusion_instruction_(fusion_instruction) {
75   param_instructions_.resize(parameter_count, nullptr);
76   bool root_found = false;
77   for (auto& instruction : *instructions) {
78     if (instruction->opcode() == HloOpcode::kParameter) {
79       int64 param_no = instruction->parameter_number();
80       CHECK(param_no >= 0 && param_no < parameter_count)
81           << "\nERROR: invalid parameter number.  Expected [0, "
82           << parameter_count << "), got " << param_no;
83       CHECK(param_instructions_[param_no] == nullptr)
84           << "\nERROR: parameter number " << param_no
85           << " already allocated in this computation";
86       param_instructions_[param_no] = instruction.get();
87     }
88     root_found |= instruction.get() == root_instruction_;
89     AddInstructionInternal(std::move(instruction));
90   }
91   CHECK(root_found)
92       << "\nERROR: root instruction is not present in computation.";
93 }
94 
AddInstruction(std::unique_ptr<HloInstruction> instruction)95 HloInstruction* HloComputation::AddInstruction(
96     std::unique_ptr<HloInstruction> instruction) {
97   CHECK(instruction->opcode() != HloOpcode::kParameter)
98       << "Parameter instructions cannot be added to a computation after "
99       << "it has been built";
100   return AddInstructionInternal(std::move(instruction));
101 }
102 
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)103 HloInstruction* HloComputation::AddInstructionInternal(
104     std::unique_ptr<HloInstruction> instruction) {
105   if (parent() != nullptr) {
106     instruction->UniquifyName(&parent()->instruction_name_uniquer());
107     instruction->SetUniqueId(parent()->NewUniqueInstructionId());
108   }
109   instruction->set_parent(this);
110   HloInstruction* pinst = instruction.get();
111   instruction_iterators_[pinst] =
112       instructions_.insert(instructions_.end(), std::move(instruction));
113   return pinst;
114 }
115 
AddParameter(std::unique_ptr<HloInstruction> instruction)116 HloInstruction* HloComputation::AddParameter(
117     std::unique_ptr<HloInstruction> instruction) {
118   CHECK(instruction->opcode() == HloOpcode::kParameter);
119   CHECK(IsFusionComputation());
120   CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
121   instruction->set_parent(this);
122   param_instructions_.push_back(instruction.get());
123   AddInstructionInternal(std::move(instruction));
124   return instructions_.back().get();
125 }
126 
AddEntryComputationParameter(std::unique_ptr<HloInstruction> instruction)127 HloInstruction* HloComputation::AddEntryComputationParameter(
128     std::unique_ptr<HloInstruction> instruction) {
129   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
130   CHECK_EQ(instruction->parameter_number(), num_parameters());
131   CHECK(parent()->entry_computation() == this);
132 
133   HloModuleConfig config = parent()->config();
134   config.mutable_entry_computation_layout()->add_parameter_layout(
135       ShapeLayout(instruction->shape()));
136   parent()->set_config(config);
137 
138   instruction->set_parent(this);
139   param_instructions_.push_back(instruction.get());
140   AddInstructionInternal(std::move(instruction));
141 
142   return instructions_.back().get();
143 }
144 
RemoveParameter(int64 param_no)145 Status HloComputation::RemoveParameter(int64 param_no) {
146   CHECK_GE(param_no, 0);
147   CHECK_LT(param_no, param_instructions_.size());
148   CHECK(IsFusionComputation());
149   HloInstruction* param_instruction = param_instructions_[param_no];
150   auto param_instruction_iterator = param_instructions_.begin() + param_no;
151   param_instructions_.erase(param_instruction_iterator);
152   // Throw removed fused parameter instruction away.
153   TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
154 
155   while (param_no < param_instructions_.size()) {
156     param_instruction = param_instructions_[param_no];
157     HloInstruction* new_instr =
158         AddInstructionInternal(HloInstruction::CreateParameter(
159             param_no, param_instruction->shape(), StrCat("param_", param_no)));
160     TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
161     param_instructions_[param_no] = new_instr;
162     TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
163     param_no++;
164   }
165 
166   return Status::OK();
167 }
168 
RemoveUnusedParameters()169 Status HloComputation::RemoveUnusedParameters() {
170   CHECK(IsFusionComputation());
171   int64 removed = 0;
172   for (int64 i = 0; i < param_instructions_.size(); ++i) {
173     HloInstruction* param_instruction = param_instructions_[i];
174     if (param_instruction->user_count() == 0 &&
175         param_instruction != root_instruction()) {
176       TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
177       ++removed;
178       continue;
179     }
180 
181     if (removed > 0) {
182       const int64 param_no = i - removed;
183       HloInstruction* new_instr = AddInstructionInternal(
184           HloInstruction::CreateParameter(param_no, param_instruction->shape(),
185                                           StrCat("param_", param_no)));
186       TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
187       param_instructions_[param_no] = new_instr;
188       TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
189     }
190   }
191   param_instructions_.resize(param_instructions_.size() - removed);
192   return Status::OK();
193 }
194 
IsRemovable(const HloInstruction * instruction)195 bool HloComputation::IsRemovable(const HloInstruction* instruction) {
196   // If the instruction has control predecessors or successors then we cannot
197   // remove the instruction without violating ordering constraints (added, for
198   // example, to avert interference due to buffer aliasing).
199   if (!instruction->control_predecessors().empty() ||
200       !instruction->control_successors().empty()) {
201     return false;
202   }
203 
204   if (instruction->opcode() == HloOpcode::kParameter &&
205       !IsFusionComputation()) {
206     return false;
207   }
208 
209   return true;
210 }
211 
HasSideEffect() const212 bool HloComputation::HasSideEffect() const {
213   for (auto* instruction : instructions()) {
214     if (instruction->HasSideEffect()) {
215       return true;
216     }
217   }
218   return false;
219 }
220 
RemoveInstructionAndUnusedOperands(HloInstruction * instruction)221 Status HloComputation::RemoveInstructionAndUnusedOperands(
222     HloInstruction* instruction) {
223   TF_RET_CHECK(root_instruction() != instruction);
224 
225   TF_RET_CHECK(instruction->user_count() == 0);
226   TF_RET_CHECK(IsRemovable(instruction))
227       << "Cannot remove instruction: " << instruction->ToString();
228   absl::flat_hash_set<HloInstruction*> removed;
229   std::queue<HloInstruction*> worklist;
230   worklist.push(instruction);
231   while (!worklist.empty()) {
232     HloInstruction* item = worklist.front();
233     worklist.pop();
234 
235     if (removed.contains(item) || item->user_count() != 0 ||
236         item == root_instruction() || !IsRemovable(item) ||
237         (item->HasSideEffect() && item != instruction)) {
238       continue;
239     }
240     for (int i = 0; i < item->operand_count(); ++i) {
241       worklist.push(item->mutable_operand(i));
242     }
243 
244     TF_RETURN_IF_ERROR(RemoveInstruction(item));
245     removed.insert(item);
246   }
247   return Status::OK();
248 }
249 
RemoveInstruction(HloInstruction * instruction)250 Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
251   VLOG(2) << "Removing instruction " << instruction->name()
252           << " from computation " << name();
253   TF_RET_CHECK(IsRemovable(instruction))
254       << "cannot remove instruction: " << instruction->ToString();
255   TF_RET_CHECK(root_instruction() != instruction)
256       << "cannot remove root instruction " << instruction->name();
257   TF_RET_CHECK(instruction->user_count() == 0)
258       << "instruction " << instruction->name()
259       << " has users and cannot be removed";
260   TF_RET_CHECK(instruction->control_predecessors().empty())
261       << "instruction " << instruction->name()
262       << " has control predecessors and cannot be removed";
263   TF_RET_CHECK(instruction->control_successors().empty())
264       << "instruction " << instruction->name()
265       << " has control successors and cannot be removed";
266 
267   auto inst_it = instruction_iterators_.find(instruction);
268   TF_RET_CHECK(inst_it != instruction_iterators_.end());
269   (*inst_it->second)->set_parent(nullptr);
270   instructions_.erase(inst_it->second);
271   instruction_iterators_.erase(inst_it);
272   return Status::OK();
273 }
274 
set_root_instruction(HloInstruction * new_root_instruction,bool accept_different_shape)275 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
276                                           bool accept_different_shape) {
277   // The shape of the root (ignoring layout) is an invariant of the computation
278   // for non-fusion cases.
279   if (!IsFusionComputation() && !accept_different_shape) {
280     CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
281                                 root_instruction_->shape()))
282         << new_root_instruction->shape() << " is incompatible with "
283         << root_instruction_->shape();
284   }
285   bool root_found = false;
286   for (auto& instruction : instructions_) {
287     if (new_root_instruction == instruction.get()) {
288       root_found = true;
289       break;
290     }
291   }
292   DCHECK(root_found);
293 
294   root_instruction_ = new_root_instruction;
295 }
296 
297 namespace {
298 
299 // Helper which builds a post order of the HLO call graph.
ComputeComputationPostOrder(HloComputation * computation,absl::flat_hash_set<HloComputation * > * visited,std::vector<HloComputation * > * post_order)300 void ComputeComputationPostOrder(HloComputation* computation,
301                                  absl::flat_hash_set<HloComputation*>* visited,
302                                  std::vector<HloComputation*>* post_order) {
303   if (visited->insert(computation).second) {
304     for (auto* instruction : computation->instructions()) {
305       for (HloComputation* called_computation :
306            instruction->called_computations()) {
307         ComputeComputationPostOrder(called_computation, visited, post_order);
308       }
309     }
310     post_order->push_back(computation);
311   }
312 }
313 
314 }  // namespace
315 
// Appends to `*post_order` every instruction reachable from `root` through
// operands and control predecessors. Each instruction comes after everything
// it depends on. Instructions already marked kVisited in `*visited` (from an
// earlier call) are not emitted again.
//
// This is an iterative DFS over an explicit stack:
//   - The first time an instruction is on top of the stack, it is marked
//     kVisiting and its predecessors are pushed above it; it stays on the
//     stack.
//   - When it reaches the top again, everything pushed above it has been
//     popped, so it is emitted and marked kVisited.
//
// Channel handling uses `channel_dependency_group`. A kRecvDone or an
// kAllReduce with an all_reduce_id is looked up there by channel id.
//   - If such an instruction is the one being expanded, the predecessors of
//     every member of its group are pushed.
//   - If such an instruction is pushed as a predecessor, every member of its
//     group is pushed in its place.
void HloComputation::ComputeInstructionPostOrder(
    const HloComputation::ChannelDependencyGroup& channel_dependency_group,
    std::vector<HloInstruction*>* post_order, HloInstruction* root,
    absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
  std::vector<HloInstruction*> dfs_stack;
  dfs_stack.push_back(root);
  while (!dfs_stack.empty()) {
    const auto current = dfs_stack.back();
    auto it = visited->find(current);
    if (it != visited->end()) {
      if (it->second == kVisited) {
        // Already visited.
        dfs_stack.pop_back();
        continue;
      }
      // Visit this node. Its predecessors were pushed above it the first time
      // it was seen and have all been popped by now.
      CHECK_EQ(kVisiting, it->second);
      dfs_stack.pop_back();
      post_order->push_back(current);
      it->second = kVisited;
      continue;
    }

    // First encounter: leave `current` on the stack and push its predecessors
    // on top of it.
    visited->insert({current, kVisiting});

    // Channel key used for grouping: channel_id for kRecvDone, all_reduce_id
    // (possibly empty) for kAllReduce, nothing otherwise.
    const auto get_channel_id =
        [](HloInstruction* inst) -> absl::optional<int64> {
      switch (inst->opcode()) {
        case HloOpcode::kRecvDone:
          return inst->channel_id();
        case HloOpcode::kAllReduce:
          return inst->all_reduce_id();
        default:
          return absl::nullopt;
      }
    };

    // When adding a predecessor to the dfs_stack, we need to also add its
    // associated channel dependencies.
    const auto add_dfs_stack = [&](HloInstruction* inst) {
      auto channel_id = get_channel_id(inst);
      if (channel_id && channel_dependency_group.count(*channel_id)) {
        auto it = channel_dependency_group.find(*channel_id);
        for (HloInstruction* cinst : it->second) {
          dfs_stack.emplace_back(cinst);
        }
      } else {
        dfs_stack.emplace_back(inst);
      }
    };

    const auto add_predecessors = [&](HloInstruction* inst) {
      // Add the operands to the stack in reverse order so the first operand is
      // processed first. This will produce a more natural ordering and a nicer
      // result for things like HLO stringification.
      const auto& operands = inst->operands();
      for (int64 i = operands.size() - 1; i >= 0; --i) {
        add_dfs_stack(operands[i]);
      }

      for (HloInstruction* op : inst->control_predecessors()) {
        add_dfs_stack(op);
      }
    };

    // If the current instruction is a channel instruction, add the dependencies
    // from all associated instructions of the channel.
    auto channel_id = get_channel_id(current);
    if (channel_id && channel_dependency_group.count(*channel_id)) {
      auto it = channel_dependency_group.find(*channel_id);
      for (HloInstruction* cinst : it->second) {
        add_predecessors(cinst);
      }
    } else {
      add_predecessors(current);
    }
  }
}
394 
395 HloComputation::ChannelDependencyGroup
ComputeChannelDependencies() const396 HloComputation::ComputeChannelDependencies() const {
397   ChannelDependencyGroup channel_dependency_group;
398   for (const auto& instruction : instructions_) {
399     switch (instruction->opcode()) {
400       case HloOpcode::kSend:
401       case HloOpcode::kRecvDone:
402         channel_dependency_group[instruction->channel_id()].push_back(
403             instruction.get());
404         break;
405       case HloOpcode::kAllReduce: {
406         auto all_reduce_id = instruction->all_reduce_id();
407         if (all_reduce_id) {
408           channel_dependency_group[all_reduce_id.value()].push_back(
409               instruction.get());
410         }
411         break;
412       }
413       default:
414         break;
415     }
416   }
417   return channel_dependency_group;
418 }
419 
MakeInstructionPostOrder() const420 std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
421   auto channel_dependency_group = ComputeChannelDependencies();
422   std::vector<HloInstruction*> post_order;
423   post_order.reserve(instruction_count());
424   std::vector<HloInstruction*> trace_instructions;
425   absl::flat_hash_map<HloInstruction*, VisitState> visited;
426   visited.reserve(instruction_count());
427   for (auto& instruction : instructions_) {
428     if (instruction->opcode() == HloOpcode::kTrace) {
429       // Trace instructions aren't handled by the DFS visitor. Add trace
430       // instructions to the post order at the end (necessarily they have no
431       // users).
432       trace_instructions.push_back(instruction.get());
433     } else if (instruction->users().empty()) {
434       ComputeInstructionPostOrder(channel_dependency_group, &post_order,
435                                   instruction.get(), &visited);
436     }
437   }
438   post_order.insert(post_order.end(), trace_instructions.begin(),
439                     trace_instructions.end());
440   CHECK_EQ(instructions_.size(), post_order.size())
441       << "number of instructions does not match post order size";
442   return post_order;
443 }
444 
MakeEmbeddedComputationsList() const445 std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
446     const {
447   absl::flat_hash_set<HloComputation*> visited;
448   std::vector<HloComputation*> post_order;
449 
450   // To avoid special handling of this computation, cast away const of
451   // 'this'. 'this' is immediately removed from the post order after
452   // construction.
453   //
454   // TODO(b/78350259): This violates const-correctness, since while the original
455   // computation is not returned, we still retrieve non-const computations from
456   // a const one. Consider also avoiding const for HloComputation, or review XLA
457   // for const-correctness of non-HloInstruction* types like this.
458   ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
459                               &post_order);
460 
461   // We don't want to include this computation in the post order.
462   CHECK_EQ(this, post_order.back());
463   post_order.pop_back();
464 
465   return post_order;
466 }
467 
ToString(const HloPrintOptions & options) const468 string HloComputation::ToString(const HloPrintOptions& options) const {
469   return ToString(options, MakeInstructionPostOrder());
470 }
471 
ToString(const HloPrintOptions & options,absl::Span<const HloInstruction * const> instruction_order) const472 string HloComputation::ToString(
473     const HloPrintOptions& options,
474     absl::Span<const HloInstruction* const> instruction_order) const {
475   CHECK_EQ(instruction_order.size(), instruction_count());
476 
477   std::ostringstream s;
478   for (int i = 0; i < options.indent_amount(); i++) {
479     s << "  ";
480   }
481 
482   if (!options.is_in_nested_computation()) {
483     if (options.print_percent()) {
484       s << "%";
485     }
486     s << name() << " ";
487   }
488 
489   if (options.print_program_shape()) {
490     s << ShapeUtil::HumanString(ComputeProgramShape()) << " ";
491   }
492   s << "{\n";
493   {
494     // Print the instructions in this computation.
495     HloPrintOptions new_options = options;
496     new_options.set_indent_amount(options.indent_amount() + 1)
497         .set_is_in_nested_computation(true);
498     CanonicalNameMap name_map;
499     for (const HloInstruction* instruction : instruction_order) {
500       CHECK_EQ(this, instruction->parent());
501 
502       for (int i = 0; i < new_options.indent_amount(); i++) {
503         s << "  ";
504       }
505       s << (instruction == root_instruction_ ? "ROOT " : "")
506         << instruction->ToStringWithCanonicalNameMap(new_options, &name_map)
507         << "\n";
508     }
509   }
510 
511   for (int i = 0; i < options.indent_amount(); i++) {
512     s << "  ";
513   }
514   s << "}";
515   return s.str();
516 }
517 
ToProto() const518 HloComputationProto HloComputation::ToProto() const {
519   HloComputationProto proto;
520   CHECK(unique_id_ != -1)
521       << "This computation does not have a valid id. Please make sure the "
522          "computation is inside a module before dumping it.";
523   proto.set_id(unique_id_);
524   proto.set_name(name_);
525   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
526     HloInstructionProto instruction_proto = instruction->ToProto();
527     proto.add_instructions()->Swap(&instruction_proto);
528   }
529   proto.set_root_id(root_instruction()->unique_id());
530   *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
531   return proto;
532 }
533 
534 /* static */ StatusOr<std::unique_ptr<HloComputation>>
CreateFromProto(const HloComputationProto & proto,const absl::flat_hash_map<int64,HloComputation * > & computation_map)535 HloComputation::CreateFromProto(
536     const HloComputationProto& proto,
537     const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
538   absl::flat_hash_map<int64, HloInstruction*> instruction_map;
539   absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
540   std::vector<std::unique_ptr<HloInstruction>> instructions;
541   int64 parameter_count = 0;
542   for (const HloInstructionProto& instruction_proto : proto.instructions()) {
543     TF_ASSIGN_OR_RETURN(
544         std::unique_ptr<HloInstruction> instruction,
545         HloInstruction::CreateFromProto(instruction_proto, instruction_map,
546                                         computation_map));
547     if (instruction->opcode() == HloOpcode::kParameter) {
548       parameter_count++;
549     }
550     TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
551     instruction_map[instruction_proto.id()] = instruction.get();
552     to_proto_id[instruction.get()] = instruction_proto.id();
553     instructions.push_back(std::move(instruction));
554   }
555 
556   TF_RET_CHECK(proto.root_id() != -1);
557   TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
558   HloInstruction* root = instruction_map.at(proto.root_id());
559 
560   // Sort the instructions in the proto id's order.
561   absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
562                                  const std::unique_ptr<HloInstruction>& b) {
563     return to_proto_id[a.get()] < to_proto_id[b.get()];
564   });
565 
566   TF_RETURN_IF_ERROR([&]() -> Status {
567     std::vector<bool> parameters_seen(parameter_count);
568     int parameters_seen_count = 0;
569     for (auto& instruction : instructions) {
570       if (instruction->opcode() == HloOpcode::kParameter) {
571         int64 param_no = instruction->parameter_number();
572         TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
573             << "Invalid parameter number.  Expected [0, " << parameter_count
574             << "), got " << param_no;
575         TF_RET_CHECK(!parameters_seen[param_no])
576             << "Parameter number " << param_no
577             << " already allocated in this computation";
578         parameters_seen[param_no] = true;
579         parameters_seen_count++;
580       }
581     }
582     TF_RET_CHECK(parameters_seen_count == parameter_count)
583         << "Not all parameters in range [0, " << parameter_count
584         << ") were referenced";
585     return Status::OK();
586   }());
587 
588   auto computation = absl::WrapUnique(
589       new HloComputation(proto.name(), parameter_count, &instructions, root,
590                          /*fusion_instruction=*/nullptr));
591   computation->unique_id_ = proto.id();
592   return std::move(computation);
593 }
594 
FuseInstructionsInto(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction * fusion_instruction)595 void HloComputation::FuseInstructionsInto(
596     absl::Span<HloInstruction* const> instructions_to_fuse,
597     HloInstruction* fusion_instruction) {
598   CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
599   HloInstruction* root = instructions_to_fuse.front();
600   TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
601   if (root == root_instruction()) {
602     set_root_instruction(fusion_instruction);
603   }
604   TF_CHECK_OK(RemoveInstruction(root));
605   for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
606     HloInstruction* instruction = instructions_to_fuse[i];
607     fusion_instruction->FuseInstruction(instruction);
608     if (instruction->user_count() == 0) {
609       TF_CHECK_OK(RemoveInstruction(instruction));
610     }
611   }
612 }
613 
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)614 HloInstruction* HloComputation::CreateFusionInstruction(
615     absl::Span<HloInstruction* const> instructions_to_fuse,
616     HloInstruction::FusionKind fusion_kind) {
617   HloInstruction* root = instructions_to_fuse.front();
618   HloInstruction* fusion_instruction = AddInstruction(
619       HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
620   FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
621   return fusion_instruction;
622 }
623 
DeepCopyHelper(HloInstruction * instruction,ShapeIndex * index,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)624 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
625     HloInstruction* instruction, ShapeIndex* index,
626     const std::function<
627         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
628                         HloComputation* computation)>& copy_leaf) {
629   if (instruction->shape().IsTuple()) {
630     std::vector<HloInstruction*> elements;
631     for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
632          i++) {
633       HloInstruction* gte =
634           AddInstruction(HloInstruction::CreateGetTupleElement(
635               ShapeUtil::GetTupleElementShape(instruction->shape(), i),
636               instruction, i));
637 
638       index->push_back(i);
639       TF_ASSIGN_OR_RETURN(HloInstruction * element,
640                           DeepCopyHelper(gte, index, copy_leaf));
641       elements.push_back(element);
642       index->pop_back();
643     }
644     return AddInstruction(HloInstruction::CreateTuple(elements));
645   }
646   if (instruction->shape().IsToken()) {
647     // Tokens have no on-device representation and cannot be copied. Pass
648     // through transparently.
649     return instruction;
650   }
651 
652   // Array shape.
653   TF_RET_CHECK(instruction->shape().IsArray());
654   return copy_leaf(instruction, *index, this);
655 }
656 
DeepCopyInstruction(HloInstruction * instruction,const ShapeTree<bool> * indices_to_copy,ShapeTree<HloInstruction * > * copies_added)657 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
658     HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
659     ShapeTree<HloInstruction*>* copies_added) {
660   if (instruction->parent() != this) {
661     return FailedPrecondition(
662         "Can't deep copy instruction %s: instruction is not in computation %s",
663         instruction->name(), name());
664   }
665   if (indices_to_copy != nullptr &&
666       !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
667     return FailedPrecondition(
668         "Can't deep copy instruction %s: given shape tree of indices to copy "
669         "has incompatible shapes: %s vs. %s",
670         instruction->name(), ShapeUtil::HumanString(instruction->shape()),
671         ShapeUtil::HumanString(indices_to_copy->shape()));
672   }
673 
674   ShapeIndex index;
675   auto copy_leaf = [indices_to_copy, copies_added](
676                        HloInstruction* leaf, const ShapeIndex& leaf_index,
677                        HloComputation* computation) {
678     if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
679       HloInstruction* copy = computation->AddInstruction(
680           HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
681       if (copies_added != nullptr) {
682         *copies_added->mutable_element(leaf_index) = copy;
683       }
684       return copy;
685     }
686     // Elements which are not to be copied are passed through
687     // transparently.
688     return leaf;
689   };
690   return DeepCopyHelper(instruction, &index, copy_leaf);
691 }
692 
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)693 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
694     HloInstruction* instruction,
695     const std::function<
696         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
697                         HloComputation* computation)>& copy_leaf) {
698   if (instruction->parent() != this) {
699     return FailedPrecondition(
700         "Can't deep copy instruction %s: instruction is not in computation %s",
701         instruction->name(), name());
702   }
703   ShapeIndex index;
704   return DeepCopyHelper(instruction, &index, copy_leaf);
705 }
706 
ComputeProgramShape() const707 ProgramShape HloComputation::ComputeProgramShape() const {
708   ProgramShape program_shape;
709 
710   for (auto* param_instruction : param_instructions_) {
711     *program_shape.add_parameters() = param_instruction->shape();
712     *program_shape.add_parameter_names() = param_instruction->name();
713   }
714   *program_shape.mutable_result() = root_instruction_->shape();
715 
716   return program_shape;
717 }
718 
operator ==(const HloComputation & other) const719 bool HloComputation::operator==(const HloComputation& other) const {
720   if (this == &other) {
721     return true;
722   }
723   absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
724       visited;
725   std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
726 
727   worklist.push_back({root_instruction(), other.root_instruction()});
728 
729   while (!worklist.empty()) {
730     auto pair = worklist.back();
731     worklist.pop_back();
732 
733     if (visited.contains(pair)) {
734       continue;
735     }
736     visited.emplace(pair);
737     // TODO(b/123082518): Avoid recursively invoking == becasue it may
738     // cause a stack overflow with deeply nested subcomputations.
739     bool identical_ignoring_operands = pair.first->Identical(
740         *pair.second,
741         [](const HloInstruction*, const HloInstruction*) { return true; },
742         [](const HloComputation* a, const HloComputation* b) {
743           return *a == *b;
744         });
745     if (!identical_ignoring_operands) {
746       return false;
747     }
748     for (size_t i = 0; i < pair.first->operands().size(); ++i) {
749       worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
750     }
751   }
752   return true;
753 }
754 
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)755 Status HloComputation::ReplaceWithNewInstruction(
756     HloInstruction* old_instruction,
757     std::unique_ptr<HloInstruction> new_instruction) {
758   return ReplaceInstruction(old_instruction,
759                             AddInstruction(std::move(new_instruction)));
760 }
761 
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)762 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
763                                           HloInstruction* new_instruction) {
764   TF_RET_CHECK(
765       ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
766       << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
767       << ShapeUtil::HumanString(new_instruction->shape());
768 
769   VLOG(10) << "transformed " << old_instruction->ToString() << " to "
770            << new_instruction->ToString();
771   // Try to add metadata for HLO instructions that are created to replace
772   // existing HLO instructions (e.g. during optimizations). The assumption is
773   // that the old instruction and the new instruction would perform the same
774   // function, and that they would be correlated to the same TF op. This might
775   // not always be correct since HLO optimizations can cross TF op boundaries.
776   // But still this seems to be better than nothing.
777   if (new_instruction->metadata().op_name().empty()) {
778     new_instruction->set_metadata(old_instruction->metadata());
779   }
780   TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
781   return RemoveInstructionAndUnusedOperands(old_instruction);
782 }
783 
CollectUnreachableRoots() const784 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
785   std::vector<HloInstruction*> unreachable_roots;
786   for (auto* instruction : instructions()) {
787     if (instruction->user_count() == 0 &&
788         instruction->control_successors().empty() &&
789         instruction != root_instruction()) {
790       unreachable_roots.push_back(instruction);
791     }
792   }
793   VLOG(3) << "Unreachable roots:"
794           << absl::StrJoin(unreachable_roots, "\n\t",
795                            [](string* out, const HloInstruction* hlo) {
796                              absl::StrAppend(out, hlo->ToString());
797                            });
798   return unreachable_roots;
799 }
800 
801 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor) const802 Status HloComputation::Accept(
803     DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
804   // Visit unreachable roots. Beware that the visitor might delete the currently
805   // visited root, which would invalidate iterators if the unreachable roots
806   // weren't computed ahead of time.
807   for (HloInstruction* root : CollectUnreachableRoots()) {
808     VLOG(3) << "Traversing unreachable root: " << root->ToString();
809     // Call FinishVisit only at the end.
810     TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
811   }
812   // Visit the computation root instruction last.
813   return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
814 }
815 
816 // Explicit instantiations.
817 template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
818 template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;
819 
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const820 Status HloComputation::AcceptWithOperandOrder(
821     DfsHloVisitor* visitor,
822     const HloInstruction::CompareFunction& operand_order) const {
823   // Visit unreachable roots. Beware that the visitor might delete the currently
824   // visited root, which would invalidate iterators if the unreachable roots
825   // weren't computed ahead of time.
826   for (HloInstruction* root : CollectUnreachableRoots()) {
827     TF_RETURN_IF_ERROR(
828         root->AcceptWithOperandOrder(visitor, operand_order,
829                                      /*call_finish_visit=*/false));
830   }
831   // Visit the computation root instruction last.
832   return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
833                                                     /*call_finish_visit=*/true);
834 }
835 
836 template <typename HloInstructionPtr>
AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr> * visitor,absl::Span<HloInstruction * const> order) const837 Status HloComputation::AcceptOrdered(
838     DfsHloVisitorBase<HloInstructionPtr>* visitor,
839     absl::Span<HloInstruction* const> order) const {
840   VLOG(3) << "Accepting visitor with order.";
841   for (HloInstruction* root : CollectUnreachableRoots()) {
842     TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
843   }
844   TF_RET_CHECK(order.size() == instruction_count());
845   absl::flat_hash_set<const HloInstruction*> visited;
846   for (const HloInstruction* instruction : order) {
847     VLOG(3) << "Visiting ordered: " << instruction->ToString();
848     TF_RET_CHECK(instruction_iterators_.contains(instruction))
849         << "Instruction " << instruction->name() << " is not in computation "
850         << name();
851     TF_RET_CHECK(!visited.contains(instruction))
852         << "Instruction " << instruction->name()
853         << " appears more than once in order";
854     HloInstruction* mutable_instruction =
855         const_cast<HloInstruction*>(instruction);
856     TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
857     TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
858     visitor->SetVisited(*mutable_instruction);
859     TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
860     visited.insert(instruction);
861   }
862   TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
863   return Status::OK();
864 }
865 
866 // Explicit instantiations.
867 template Status HloComputation::AcceptOrdered(
868     DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
869 template Status HloComputation::AcceptOrdered(
870     ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;
871 
Accept(const std::function<Status (HloInstruction *)> & visitor_func)872 Status HloComputation::Accept(
873     const std::function<Status(HloInstruction*)>& visitor_func) {
874   FunctionVisitor visitor(visitor_func);
875   return this->Accept(&visitor);
876 }
877 
Accept(const std::function<Status (const HloInstruction *)> & visitor_func) const878 Status HloComputation::Accept(
879     const std::function<Status(const HloInstruction*)>& visitor_func) const {
880   ConstFunctionVisitor visitor(visitor_func);
881   return this->Accept(&visitor);
882 }
883 
Clone(const string & suffix,HloCloneContext * context)884 std::unique_ptr<HloComputation> HloComputation::Clone(
885     const string& suffix, HloCloneContext* context) {
886   return CloneWithReplacements(
887       /*replacements=*/absl::flat_hash_map<const HloInstruction*,
888                                            std::unique_ptr<HloInstruction>>(),
889       /*extra_parameters=*/{}, context, suffix);
890 }
891 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)892 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
893     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
894     HloCloneContext* context, const string& suffix) {
895   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
896       replacements;
897   replacements.emplace(std::move(r1));
898   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
899                                context, suffix);
900 }
901 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)902 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
903     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
904     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
905     HloCloneContext* context, const string& suffix) {
906   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
907       replacements;
908   replacements.emplace(std::move(r1));
909   replacements.emplace(std::move(r2));
910   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
911                                context, suffix);
912 }
913 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)914 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
915     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
916     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
917     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
918     HloCloneContext* context, const string& suffix) {
919   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
920       replacements;
921   replacements.emplace(std::move(r1));
922   replacements.emplace(std::move(r2));
923   replacements.emplace(std::move(r3));
924   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
925                                context, suffix);
926 }
927 
CloneWithReplacements(absl::flat_hash_map<const HloInstruction *,std::unique_ptr<HloInstruction>> replacements,absl::Span<const HloInstruction * const> extra_parameters,HloCloneContext * context,const string & suffix)928 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
929     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
930         replacements,
931     absl::Span<const HloInstruction* const> extra_parameters,
932     HloCloneContext* context, const string& suffix) {
933   std::unique_ptr<HloCloneContext> context_ptr;
934   if (context == nullptr) {
935     context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
936     context = context_ptr.get();
937   }
938 
939   // Look up instr in the replacements map, and return either the replacement,
940   // or instr, if the replacement isn't present.
941   //
942   // Note: This can return null, indicating that instr should not be present in
943   // the new computation.
944   auto replace = [&](HloInstruction* instr) {
945     auto it = replacements.find(instr);
946     if (it == replacements.end()) {
947       return instr;
948     }
949     return it->second.get();
950   };
951 
952   VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
953 
954   // We want to do a postorder walk over [replace(i) for i in instructions_].
955   // We can't reuse MakeInstructionPostOrder() for this, because that will
956   // generate a postorder of plain instructions_, and our replacements may
957   // change the postorder!
958   //
959   // The postorder we want here is simpler than what MakeInstructionPostOrder()
960   // does -- we only care about operand dependencies -- so let's just do it
961   // ourselves.
962   std::vector<HloInstruction*> postorder;
963   absl::flat_hash_map<HloInstruction*, VisitState> visited;
964   for (const auto& instr : instructions_) {
965     std::vector<HloInstruction*> dfs_stack;
966     HloInstruction* new_instr = replace(instr.get());
967     if (!new_instr) {
968       continue;
969     }
970     dfs_stack.push_back(new_instr);
971 
972     while (!dfs_stack.empty()) {
973       auto* cur = dfs_stack.back();
974       auto it = visited.find(cur);
975       if (it != visited.end()) {
976         dfs_stack.pop_back();
977         if (it->second == kVisited) {
978           continue;
979         }
980         CHECK_EQ(it->second, kVisiting);
981         postorder.push_back(cur);
982         it->second = kVisited;
983         continue;
984       }
985 
986       visited.insert({cur, kVisiting});
987       for (HloInstruction* operand : cur->operands()) {
988         HloInstruction* new_operand = replace(operand);
989         if (new_operand) {
990           dfs_stack.emplace_back(new_operand);
991         }
992       }
993     }
994   }
995 
996   std::vector<std::unique_ptr<HloInstruction>> instructions;
997   // First add the extra parameters to 'instructions'.
998   for (const auto& instr : extra_parameters) {
999     CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
1000         << "Only parameter instructions are allowed in 'extra_parameters'";
1001     instructions.emplace_back(instr->Clone());
1002   }
1003   for (auto instr : postorder) {
1004     std::vector<HloInstruction*> new_operands;
1005     for (auto operand : instr->operands()) {
1006       auto replaced_operand = replace(operand);
1007       CHECK_NE(replaced_operand, nullptr)
1008           << "replacements map tried to eliminate a used instruction "
1009           << operand->ToString() << ", used by " << instr->ToString();
1010       new_operands.push_back(context->GetInstruction(replaced_operand));
1011     }
1012     instructions.push_back(
1013         instr->CloneWithNewOperands(instr->shape(), new_operands, context));
1014   }
1015   Builder builder(name() + "." + suffix);
1016   for (auto& instr : instructions) {
1017     builder.AddInstruction(std::move(instr));
1018   }
1019   auto result = builder.Build(
1020       /*root_instruction=*/context->GetInstruction(
1021           replace(root_instruction())));
1022 
1023   // Clone control dependencies.
1024   for (auto instr : postorder) {
1025     HloInstruction* new_instr = context->GetInstruction(instr);
1026     for (auto successor : instr->control_successors()) {
1027       auto replaced_successor = replace(successor);
1028       // successor may not have been remapped, because it might have been
1029       // removed by the replacements map.
1030       if (replaced_successor != nullptr) {
1031         TF_CHECK_OK(new_instr->AddControlDependencyTo(
1032             context->GetInstruction(replaced_successor)));
1033       }
1034     }
1035   }
1036   context->MapComputation(this, result.get());
1037   return result;
1038 }
1039 
UniquifyName(NameUniquer * name_uniquer)1040 void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
1041   name_ = name_uniquer->GetUniqueName(name_);
1042 }
1043 
GetInstructionWithName(absl::string_view name)1044 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1045   auto instructions_in_computation = instructions();
1046   auto it = absl::c_find_if(
1047       instructions_in_computation,
1048       [&](HloInstruction* instr) { return instr->name() == name; });
1049   return it == instructions_in_computation.end() ? nullptr : *it;
1050 }
1051 
1052 }  // namespace xla
1053