1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46
47 namespace xla {
48
49 using absl::StrCat;
50
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52 HloInstruction* root_instruction) {
53 int parameter_count = 0;
54 for (auto& instruction : instructions_) {
55 if (instruction->opcode() == HloOpcode::kParameter) {
56 parameter_count++;
57 }
58 }
59 // If root_instruction is not specified use the last added instruction.
60 HloInstruction* root =
61 root_instruction ? root_instruction : last_added_instruction_;
62 CHECK_NE(nullptr, root);
63 return absl::WrapUnique(new HloComputation(
64 name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
66
HloComputation(const string & name,int parameter_count,std::vector<std::unique_ptr<HloInstruction>> * instructions,HloInstruction * root_instruction,HloInstruction * fusion_instruction)67 HloComputation::HloComputation(
68 const string& name, int parameter_count,
69 std::vector<std::unique_ptr<HloInstruction>>* instructions,
70 HloInstruction* root_instruction, HloInstruction* fusion_instruction)
71 : name_(NameUniquer::GetSanitizedName(name)),
72 unique_id_(-1),
73 root_instruction_(root_instruction),
74 fusion_instruction_(fusion_instruction) {
75 param_instructions_.resize(parameter_count, nullptr);
76 bool root_found = false;
77 for (auto& instruction : *instructions) {
78 if (instruction->opcode() == HloOpcode::kParameter) {
79 int64 param_no = instruction->parameter_number();
80 CHECK(param_no >= 0 && param_no < parameter_count)
81 << "\nERROR: invalid parameter number. Expected [0, "
82 << parameter_count << "), got " << param_no;
83 CHECK(param_instructions_[param_no] == nullptr)
84 << "\nERROR: parameter number " << param_no
85 << " already allocated in this computation";
86 param_instructions_[param_no] = instruction.get();
87 }
88 root_found |= instruction.get() == root_instruction_;
89 AddInstructionInternal(std::move(instruction));
90 }
91 CHECK(root_found)
92 << "\nERROR: root instruction is not present in computation.";
93 }
94
AddInstruction(std::unique_ptr<HloInstruction> instruction)95 HloInstruction* HloComputation::AddInstruction(
96 std::unique_ptr<HloInstruction> instruction) {
97 CHECK(instruction->opcode() != HloOpcode::kParameter)
98 << "Parameter instructions cannot be added to a computation after "
99 << "it has been built";
100 return AddInstructionInternal(std::move(instruction));
101 }
102
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)103 HloInstruction* HloComputation::AddInstructionInternal(
104 std::unique_ptr<HloInstruction> instruction) {
105 if (parent() != nullptr) {
106 instruction->UniquifyName(&parent()->instruction_name_uniquer());
107 instruction->SetUniqueId(parent()->NewUniqueInstructionId());
108 }
109 instruction->set_parent(this);
110 HloInstruction* pinst = instruction.get();
111 instruction_iterators_[pinst] =
112 instructions_.insert(instructions_.end(), std::move(instruction));
113 return pinst;
114 }
115
AddParameter(std::unique_ptr<HloInstruction> instruction)116 HloInstruction* HloComputation::AddParameter(
117 std::unique_ptr<HloInstruction> instruction) {
118 CHECK(instruction->opcode() == HloOpcode::kParameter);
119 CHECK(IsFusionComputation());
120 CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
121 instruction->set_parent(this);
122 param_instructions_.push_back(instruction.get());
123 AddInstructionInternal(std::move(instruction));
124 return instructions_.back().get();
125 }
126
AddEntryComputationParameter(std::unique_ptr<HloInstruction> instruction)127 HloInstruction* HloComputation::AddEntryComputationParameter(
128 std::unique_ptr<HloInstruction> instruction) {
129 CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
130 CHECK_EQ(instruction->parameter_number(), num_parameters());
131 CHECK(parent()->entry_computation() == this);
132
133 HloModuleConfig config = parent()->config();
134 config.mutable_entry_computation_layout()->add_parameter_layout(
135 ShapeLayout(instruction->shape()));
136 parent()->set_config(config);
137
138 instruction->set_parent(this);
139 param_instructions_.push_back(instruction.get());
140 AddInstructionInternal(std::move(instruction));
141
142 return instructions_.back().get();
143 }
144
ReplaceEntryComputationParameter(int64 param_no,HloInstruction * old_instruction,std::unique_ptr<HloInstruction> instruction)145 Status HloComputation::ReplaceEntryComputationParameter(
146 int64 param_no, HloInstruction* old_instruction,
147 std::unique_ptr<HloInstruction> instruction) {
148 CHECK_GE(param_no, 0);
149 CHECK_LT(param_no, param_instructions_.size());
150 CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
151 CHECK(parent()->entry_computation() == this);
152
153 HloModuleConfig config = parent()->config();
154 *config.mutable_entry_computation_layout()->mutable_parameter_layout(
155 param_no) = ShapeLayout(instruction->shape());
156 parent()->set_config(config);
157
158 instruction->set_parent(this);
159 param_instructions_[param_no] = instruction.get();
160 AddInstructionInternal(std::move(instruction));
161
162 return ForceRemoveInstruction(old_instruction);
163 }
164
RemoveParameter(int64 param_no)165 Status HloComputation::RemoveParameter(int64 param_no) {
166 CHECK_GE(param_no, 0);
167 CHECK_LT(param_no, param_instructions_.size());
168 CHECK(IsFusionComputation());
169 HloInstruction* param_instruction = param_instructions_[param_no];
170 auto param_instruction_iterator = param_instructions_.begin() + param_no;
171 param_instructions_.erase(param_instruction_iterator);
172 // Throw removed fused parameter instruction away.
173 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
174
175 while (param_no < param_instructions_.size()) {
176 param_instruction = param_instructions_[param_no];
177 HloInstruction* new_instr =
178 AddInstructionInternal(HloInstruction::CreateParameter(
179 param_no, param_instruction->shape(), StrCat("param_", param_no)));
180 TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
181 param_instructions_[param_no] = new_instr;
182 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
183 param_no++;
184 }
185
186 return Status::OK();
187 }
188
// Removes dead parameters; restricted to fusion computations, whose
// signature may legally change.
RemoveUnusedParametersFromFusedComputation()189 Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
190 return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
191 }
192
// Removes dead parameters from any computation; the caller is responsible
// for updating all call sites to match the new signature.
RemoveUnusedParametersFromAnyComputation()193 Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
194 return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
195 }
196
// Shared implementation for the RemoveUnusedParameters* entry points.
// Deletes every parameter that has no users (and is not the root), then
// renumbers the surviving parameters so their numbers stay dense in [0, n).
RemoveUnusedParametersImpl(bool allow_non_fusion)197 Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
198 CHECK(allow_non_fusion || IsFusionComputation());
199 int64 removed = 0;
200 for (int64 i = 0; i < param_instructions_.size(); ++i) {
201 HloInstruction* param_instruction = param_instructions_[i];
202 if (param_instruction->user_count() == 0 &&
203 param_instruction != root_instruction()) {
// Unused parameter: delete it and remember the shift for later entries.
204 TF_RETURN_IF_ERROR(
205 RemoveInstructionImpl(param_instruction, allow_non_fusion));
206 ++removed;
207 continue;
208 }
209
// Earlier removals shifted this parameter's position; recreate it with
// its new, lower parameter number and redirect all uses to the clone.
210 if (removed > 0) {
211 const int64 param_no = i - removed;
212 HloInstruction* new_instr = AddInstructionInternal(
213 HloInstruction::CreateParameter(param_no, param_instruction->shape(),
214 StrCat("param_", param_no)));
215 TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
216 param_instructions_[param_no] = new_instr;
217 TF_RETURN_IF_ERROR(
218 RemoveInstructionImpl(param_instruction, allow_non_fusion));
219 }
220 }
// Drop the now-unused tail slots of the parameter table.
221 param_instructions_.resize(param_instructions_.size() - removed);
222 return Status::OK();
223 }
224
IsSafelyRemovable(const HloInstruction * instruction)225 bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
226 // If the instruction has control predecessors or successors then we cannot
227 // remove the instruction without violating ordering constraints (added, for
228 // example, to avert interference due to buffer aliasing).
229 if (!instruction->control_predecessors().empty() ||
230 !instruction->control_successors().empty()) {
231 return false;
232 }
233
234 if (instruction->opcode() == HloOpcode::kParameter &&
235 !IsFusionComputation()) {
236 return false;
237 }
238
239 return true;
240 }
241
HasSideEffect() const242 bool HloComputation::HasSideEffect() const {
243 for (auto* instruction : instructions()) {
244 if (instruction->HasSideEffect()) {
245 return true;
246 }
247 }
248 return false;
249 }
250
// Removes `instruction` and, transitively, any of its operands that become
// unused as a result. `cleanup`, when provided, is invoked on each
// instruction just before it is removed. Side-effecting operands are never
// removed (only `instruction` itself may have a side effect).
RemoveInstructionAndUnusedOperands(HloInstruction * instruction,std::function<void (HloInstruction *)> cleanup)251 Status HloComputation::RemoveInstructionAndUnusedOperands(
252 HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
253 TF_RET_CHECK(root_instruction() != instruction);
254
255 TF_RET_CHECK(instruction->user_count() == 0);
256 TF_RET_CHECK(IsSafelyRemovable(instruction))
257 << "Cannot remove instruction: " << instruction->ToString();
258 absl::flat_hash_set<HloInstruction*> removed;
259 std::queue<HloInstruction*> worklist;
260 worklist.push(instruction);
261 while (!worklist.empty()) {
262 HloInstruction* item = worklist.front();
263 worklist.pop();
264
// Skip anything already removed, still used, the root, unsafe to remove,
// or side-effecting (the initially requested instruction excepted).
265 if (removed.contains(item) || item->user_count() != 0 ||
266 item == root_instruction() || !IsSafelyRemovable(item) ||
267 (item->HasSideEffect() && item != instruction)) {
268 continue;
269 }
// Enqueue operands before removal; they may become dead once `item` goes.
270 for (int i = 0; i < item->operand_count(); ++i) {
271 worklist.push(item->mutable_operand(i));
272 }
273
274 if (cleanup) {
275 cleanup(item);
276 }
277 TF_RETURN_IF_ERROR(RemoveInstruction(item));
278 removed.insert(item);
279 }
280 return Status::OK();
281 }
282
// Removes `instruction` after running the standard safety checks
// (removability, non-root, no users, no control dependencies).
RemoveInstruction(HloInstruction * instruction)283 Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
284 return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
285 }
286
// Removes `instruction`, bypassing the IsSafelyRemovable() check. The
// remaining invariants (non-root, no users, no control deps) still apply.
ForceRemoveInstruction(HloInstruction * instruction)287 Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
288 return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
289 }
290
// Detaches `instruction` from this computation and destroys it. The
// instruction must not be the root, must have no users, and must have no
// control dependencies; `ignore_safety_check` only skips the
// IsSafelyRemovable() test.
RemoveInstructionImpl(HloInstruction * instruction,bool ignore_safety_check)291 Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
292 bool ignore_safety_check) {
293 VLOG(2) << "Removing instruction " << instruction->name()
294 << " from computation " << name();
295 TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
296 << "cannot remove instruction: " << instruction->ToString();
297 TF_RET_CHECK(root_instruction() != instruction)
298 << "cannot remove root instruction " << instruction->name();
299 TF_RET_CHECK(instruction->user_count() == 0)
300 << "instruction " << instruction->name()
301 << " has users and cannot be removed";
302 TF_RET_CHECK(instruction->control_predecessors().empty())
303 << "instruction " << instruction->name()
304 << " has control predecessors and cannot be removed";
305 TF_RET_CHECK(instruction->control_successors().empty())
306 << "instruction " << instruction->name()
307 << " has control successors and cannot be removed";
308
// Erase both the owning list entry (which destroys the instruction) and
// its cached iterator; clear the parent pointer first.
309 auto inst_it = instruction_iterators_.find(instruction);
310 TF_RET_CHECK(inst_it != instruction_iterators_.end());
311 (*inst_it->second)->set_parent(nullptr);
312 instructions_.erase(inst_it->second);
313 instruction_iterators_.erase(inst_it);
314 return Status::OK();
315 }
316
set_root_instruction(HloInstruction * new_root_instruction,bool accept_different_shape)317 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
318 bool accept_different_shape) {
319 // The shape of the root (ignoring layout) is an invariant of the computation
320 // for non-fusion cases.
321 if (!IsFusionComputation() && !accept_different_shape) {
322 CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
323 root_instruction_->shape()))
324 << new_root_instruction->shape() << " is incompatible with "
325 << root_instruction_->shape();
326 }
327 bool root_found = false;
328 for (auto& instruction : instructions_) {
329 if (new_root_instruction == instruction.get()) {
330 root_found = true;
331 break;
332 }
333 }
334 DCHECK(root_found);
335
336 if (parent() && parent()->has_entry_computation() &&
337 parent()->entry_computation() == this) {
338 if (!Shape::Equal()(new_root_instruction->shape(),
339 root_instruction_->shape())) {
340 // Rebuild input output alias config now that we have a new output shape.
341 parent()->input_output_alias_config() =
342 HloInputOutputAliasConfig(new_root_instruction->shape());
343 }
344 }
345
346 root_instruction_ = new_root_instruction;
347 }
348
349 namespace {
350
351 // Helper which builds a post order of the HLO call graph.
ComputeComputationPostOrder(HloComputation * computation,absl::flat_hash_set<HloComputation * > * visited,std::vector<HloComputation * > * post_order)352 void ComputeComputationPostOrder(HloComputation* computation,
353 absl::flat_hash_set<HloComputation*>* visited,
354 std::vector<HloComputation*>* post_order) {
355 if (visited->insert(computation).second) {
356 for (auto* instruction : computation->instructions()) {
357 for (HloComputation* called_computation :
358 instruction->called_computations()) {
359 ComputeComputationPostOrder(called_computation, visited, post_order);
360 }
361 }
362 post_order->push_back(computation);
363 }
364 }
365
366 } // namespace
367
// Appends to `post_order` a post order of the instructions reachable from
// `root`, using an iterative DFS. `visited` is a tri-state map (absent /
// kVisiting / kVisited) shared across calls so instructions are emitted at
// most once. Instructions that share a channel id (per
// `channel_dependency_group`) are expanded together so that each channel
// group stays contiguous in the resulting order.
ComputeInstructionPostOrder(const HloComputation::ChannelDependencyGroup & channel_dependency_group,std::vector<HloInstruction * > * post_order,HloInstruction * root,absl::flat_hash_map<HloInstruction *,VisitState> * visited) const368 void HloComputation::ComputeInstructionPostOrder(
369 const HloComputation::ChannelDependencyGroup& channel_dependency_group,
370 std::vector<HloInstruction*>* post_order, HloInstruction* root,
371 absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
372 std::vector<HloInstruction*> dfs_stack;
373 dfs_stack.push_back(root);
374 while (!dfs_stack.empty()) {
375 const auto current = dfs_stack.back();
376 auto it = visited->find(current);
377 if (it != visited->end()) {
378 if (it->second == kVisited) {
379 // Already visited.
380 dfs_stack.pop_back();
381 continue;
382 }
// Second encounter on the stack: all predecessors have been emitted, so
// this node can now be appended to the post order.
383 // Visit this node.
384 CHECK_EQ(kVisiting, it->second);
385 dfs_stack.pop_back();
386 post_order->push_back(current);
387 it->second = kVisited;
388 continue;
389 }
390
// First encounter: mark in-progress and push predecessors; the node stays
// on the stack and is emitted when popped again above.
391 visited->insert({current, kVisiting});
392
// Only recv-done and all-reduce contribute channel dependencies here.
393 const auto get_channel_id =
394 [](HloInstruction* inst) -> absl::optional<int64> {
395 switch (inst->opcode()) {
396 case HloOpcode::kRecvDone:
397 return inst->channel_id();
398 case HloOpcode::kAllReduce:
399 return inst->channel_id();
400 default:
401 return absl::nullopt;
402 }
403 };
404
405 // When adding a predecessor to the dfs_stack, we need to also add its
406 // associated channel dependencies.
407 const auto add_dfs_stack = [&](HloInstruction* inst) {
408 auto channel_id = get_channel_id(inst);
409 if (channel_id && channel_dependency_group.count(*channel_id)) {
410 auto it = channel_dependency_group.find(*channel_id);
411 for (HloInstruction* cinst : it->second) {
412 dfs_stack.emplace_back(cinst);
413 }
414 } else {
415 dfs_stack.emplace_back(inst);
416 }
417 };
418
419 const auto add_predecessors = [&](HloInstruction* inst) {
420 // Add the operands to the stack in reverse order so the first operand is
421 // processed first. This will produce a more natural ordering and a nicer
422 // result for things like HLO stringification.
423 const auto& operands = inst->operands();
424 for (int64 i = operands.size() - 1; i >= 0; --i) {
425 add_dfs_stack(operands[i]);
426 }
427
428 for (HloInstruction* op : inst->control_predecessors()) {
429 add_dfs_stack(op);
430 }
431 };
432
433 // If the current instruction is a channel instruction, add the dependencies
434 // from all associated instructions of the channel.
435 auto channel_id = get_channel_id(current);
436 if (channel_id && channel_dependency_group.count(*channel_id)) {
437 auto it = channel_dependency_group.find(*channel_id);
438 for (HloInstruction* cinst : it->second) {
439 add_predecessors(cinst);
440 }
441 } else {
442 add_predecessors(current);
443 }
444 }
445 }
446
447 HloComputation::ChannelDependencyGroup
ComputeChannelDependencies() const448 HloComputation::ComputeChannelDependencies() const {
449 ChannelDependencyGroup channel_dependency_group;
450 for (const auto& instruction : instructions_) {
451 switch (instruction->opcode()) {
452 case HloOpcode::kSend:
453 case HloOpcode::kRecvDone:
454 case HloOpcode::kAllReduce: {
455 auto channel_id = instruction->channel_id();
456 if (channel_id) {
457 channel_dependency_group[channel_id.value()].push_back(
458 instruction.get());
459 }
460 break;
461 }
462 default:
463 break;
464 }
465 }
466 return channel_dependency_group;
467 }
468
HasOnlyTraceUsers(const HloInstruction * instruction)469 static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
470 return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
471 return user->opcode() == HloOpcode::kTrace;
472 });
473 }
474
// Returns every instruction of this computation exactly once, in an order
// where each instruction appears after all of its operands and control
// predecessors. The DFS is seeded only at instructions whose users are all
// kTrace instructions (i.e. the effective roots), which suffices to reach
// everything else; the final CHECK verifies full coverage.
MakeInstructionPostOrder() const475 std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
476 auto channel_dependency_group = ComputeChannelDependencies();
477 std::vector<HloInstruction*> post_order;
478 post_order.reserve(instruction_count());
479 std::vector<HloInstruction*> trace_instructions;
480 absl::flat_hash_map<HloInstruction*, VisitState> visited;
481 visited.reserve(instruction_count());
482 for (auto& instruction : instructions_) {
483 if (instruction->opcode() == HloOpcode::kTrace) {
484 // Trace instructions aren't handled by the DFS visitor. Add trace
485 // instructions to the post order at the end (necessarily they have no
486 // users).
487 trace_instructions.push_back(instruction.get());
488 } else if (HasOnlyTraceUsers(instruction.get())) {
489 ComputeInstructionPostOrder(channel_dependency_group, &post_order,
490 instruction.get(), &visited);
491 }
492 }
493 post_order.insert(post_order.end(), trace_instructions.begin(),
494 trace_instructions.end());
495 CHECK_EQ(instructions_.size(), post_order.size())
496 << "number of instructions does not match post order size";
497 return post_order;
498 }
499
MakeEmbeddedComputationsList() const500 std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
501 const {
502 absl::flat_hash_set<HloComputation*> visited;
503 std::vector<HloComputation*> post_order;
504
505 // To avoid special handling of this computation, cast away const of
506 // 'this'. 'this' is immediately removed from the post order after
507 // construction.
508 //
509 // TODO(b/78350259): This violates const-correctness, since while the original
510 // computation is not returned, we still retrieve non-const computations from
511 // a const one. Consider also avoiding const for HloComputation, or review XLA
512 // for const-correctness of non-HloInstruction* types like this.
513 ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
514 &post_order);
515
516 // We don't want to include this computation in the post order.
517 CHECK_EQ(this, post_order.back());
518 post_order.pop_back();
519
520 return post_order;
521 }
522
// Stringifies the computation using the default instruction post order.
ToString(const HloPrintOptions & options) const523 string HloComputation::ToString(const HloPrintOptions& options) const {
524 return ToString(options, MakeInstructionPostOrder());
525 }
526
// Renders this computation as HLO text with the instructions emitted in
// `instruction_order` (which must contain every instruction exactly once).
// When the options select only a subset of instructions for printing, a
// window of surrounding context is included and elided runs are rendered as
// a single "..." line.
ToString(const HloPrintOptions & options,absl::Span<const HloInstruction * const> instruction_order) const527 string HloComputation::ToString(
528 const HloPrintOptions& options,
529 absl::Span<const HloInstruction* const> instruction_order) const {
530 CHECK_EQ(instruction_order.size(), instruction_count());
531
532 const string tab(2 * options.indent_amount(), ' ');
533
534 std::ostringstream s;
535 s << tab;
536
// Only top-level computations print their own name here; nested ones are
// named by their caller.
537 if (!options.is_in_nested_computation()) {
538 if (options.print_percent()) {
539 s << "%";
540 }
541 s << PrintName(name(), options.print_ids()) << " ";
542 }
543
544 if (options.print_program_shape()) {
545 s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
546 << " ";
547 }
548 s << "{\n";
549
550 // There are instructions which are required to be printed. Additionally, we
551 // print some instructions before and after required ones. The resulting
552 // output has the following format.
553 //
554 // computation {
555 // ...
556 // additional_instructions
557 // required_instructions
558 // additional_instructions
559 // ...
560 // additional_instructions
561 // required_instructions
562 // additional_instructions
563 // ...
564 // }
565 std::set<int> instructions_to_print;
566 {
567 // Find all the instructions that should be printed.
568 auto add_instruction = [&instructions_to_print,
569 &instruction_order](int index) {
// Silently ignore out-of-range context indices at the computation edges.
570 if (index < 0 || index >= instruction_order.size()) {
571 return;
572 }
573 instructions_to_print.insert(index);
574 };
575
// Marks `index` plus the configured number of leading/trailing context
// instructions around it.
576 auto add_instructions_arround = [&add_instruction, &options](int index) {
577 for (int i = index - options.leading_and_trailing_instructions_number();
578 i <= index + options.leading_and_trailing_instructions_number();
579 ++i) {
580 add_instruction(i);
581 }
582 };
583
584 for (int i = 0; i < instruction_order.size(); ++i) {
585 const HloInstruction* instruction = instruction_order[i];
586 CHECK_EQ(this, instruction->parent());
587 if (options.print_instruction(instruction)) {
588 add_instructions_arround(i);
589 }
590 }
591 }
592
593 {
594 // Print the instructions in this computation.
595 HloPrintOptions new_options = options;
596 new_options.set_indent_amount(options.indent_amount() + 1)
597 .set_is_in_nested_computation(true);
598
599 const string new_tab(2 * new_options.indent_amount(), ' ');
600
601 CanonicalNameMap name_map;
602
// print_prev tracks whether the previous instruction was printed, so that
// "..." is emitted exactly once per elided run.
603 bool print_prev = true;
604 for (int index = 0; index < instruction_order.size(); ++index) {
605 const HloInstruction* instruction = instruction_order[index];
606 if (instructions_to_print.find(index) != instructions_to_print.end()) {
607 s << new_options.format_instruction(
608 instruction,
609 instruction->ToStringWithCanonicalNameMap(new_options,
610 &name_map),
611 new_options.indent_amount(), instruction == root_instruction_)
612 << "\n";
613 print_prev = true;
614 } else if (print_prev) {
615 s << new_tab << "...\n";
616 print_prev = false;
617 }
618 }
619 }
620
621 s << tab << "}";
622 return s.str();
623 }
624
ToProto() const625 HloComputationProto HloComputation::ToProto() const {
626 HloComputationProto proto;
627 CHECK(unique_id_ != -1)
628 << "This computation does not have a valid id. Please make sure the "
629 "computation is inside a module before dumping it.";
630 proto.set_id(unique_id_);
631 proto.set_name(name_);
632 for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
633 HloInstructionProto instruction_proto = instruction->ToProto();
634 proto.add_instructions()->Swap(&instruction_proto);
635 }
636 proto.set_root_id(root_instruction()->unique_id());
637 *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
638 return proto;
639 }
640
// Reconstructs an HloComputation from its serialized form. `computation_map`
// supplies already-deserialized computations referenced by id (e.g. fusion or
// called computations). Instructions are rebuilt in proto order, re-sorted by
// proto id, and the parameter numbering is validated to be dense before the
// computation is constructed.
CreateFromProto(const HloComputationProto & proto,const absl::flat_hash_map<int64,HloComputation * > & computation_map,bool prohibit_empty_literal)642 HloComputation::CreateFromProto(
643 const HloComputationProto& proto,
644 const absl::flat_hash_map<int64, HloComputation*>& computation_map,
645 bool prohibit_empty_literal) {
646 absl::flat_hash_map<int64, HloInstruction*> instruction_map;
647 absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
648 std::vector<std::unique_ptr<HloInstruction>> instructions;
649 int64 parameter_count = 0;
650 for (const HloInstructionProto& instruction_proto : proto.instructions()) {
651 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
652 HloInstruction::CreateFromProto(
653 instruction_proto, instruction_map, computation_map,
654 prohibit_empty_literal));
655 if (instruction->opcode() == HloOpcode::kParameter) {
656 parameter_count++;
657 }
// Proto ids must be unique within the computation.
658 TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
659 instruction_map[instruction_proto.id()] = instruction.get();
660 to_proto_id[instruction.get()] = instruction_proto.id();
661 instructions.push_back(std::move(instruction));
662 }
663
664 TF_RET_CHECK(proto.root_id() != -1);
665 TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
666 HloInstruction* root = instruction_map.at(proto.root_id());
667
668 // Sort the instructions in the proto id's order.
669 absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
670 const std::unique_ptr<HloInstruction>& b) {
671 return to_proto_id[a.get()] < to_proto_id[b.get()];
672 });
673
// Validate that the parameter numbers form exactly [0, parameter_count)
// with no duplicates, before handing the instructions to the constructor.
674 TF_RETURN_IF_ERROR([&]() -> Status {
675 std::vector<bool> parameters_seen(parameter_count);
676 int parameters_seen_count = 0;
677 for (auto& instruction : instructions) {
678 if (instruction->opcode() == HloOpcode::kParameter) {
679 int64 param_no = instruction->parameter_number();
680 TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
681 << "Invalid parameter number. Expected [0, " << parameter_count
682 << "), got " << param_no;
683 TF_RET_CHECK(!parameters_seen[param_no])
684 << "Parameter number " << param_no
685 << " already allocated in this computation";
686 parameters_seen[param_no] = true;
687 parameters_seen_count++;
688 }
689 }
690 TF_RET_CHECK(parameters_seen_count == parameter_count)
691 << "Not all parameters in range [0, " << parameter_count
692 << ") were referenced";
693 return Status::OK();
694 }());
695
696 auto computation = absl::WrapUnique(
697 new HloComputation(proto.name(), parameter_count, &instructions, root,
698 /*fusion_instruction=*/nullptr));
699 computation->unique_id_ = proto.id();
700 return std::move(computation);
701 }
702
// Moves `instructions_to_fuse` into `fusion_instruction`. The first element
// is the fusion root: all of its uses are redirected to the fusion
// instruction (which also takes over as computation root if needed) and it
// is removed from this computation. The remaining instructions are fused in
// order and removed once they have no users left.
FuseInstructionsInto(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction * fusion_instruction)703 void HloComputation::FuseInstructionsInto(
704 absl::Span<HloInstruction* const> instructions_to_fuse,
705 HloInstruction* fusion_instruction) {
706 CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
707 HloInstruction* root = instructions_to_fuse.front();
708 TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
709 if (root == root_instruction()) {
710 set_root_instruction(fusion_instruction);
711 }
712 TF_CHECK_OK(RemoveInstruction(root));
713 for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
714 HloInstruction* instruction = instructions_to_fuse[i];
715 fusion_instruction->FuseInstruction(instruction);
// Instructions still used outside the fusion (e.g. by later list entries)
// are left in place; they get removed when their last use is fused.
716 if (instruction->user_count() == 0) {
717 TF_CHECK_OK(RemoveInstruction(instruction));
718 }
719 }
720 }
721
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)722 HloInstruction* HloComputation::CreateFusionInstruction(
723 absl::Span<HloInstruction* const> instructions_to_fuse,
724 HloInstruction::FusionKind fusion_kind) {
725 HloInstruction* root = instructions_to_fuse.front();
726 HloInstruction* fusion_instruction = AddInstruction(
727 HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
728 FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
729 return fusion_instruction;
730 }
731
DeepCopyHelper(HloInstruction * instruction,ShapeIndex * index,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)732 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
733 HloInstruction* instruction, ShapeIndex* index,
734 const std::function<
735 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
736 HloComputation* computation)>& copy_leaf) {
737 if (instruction->shape().IsTuple()) {
738 std::vector<HloInstruction*> elements;
739 for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
740 i++) {
741 HloInstruction* gte =
742 AddInstruction(HloInstruction::CreateGetTupleElement(
743 ShapeUtil::GetTupleElementShape(instruction->shape(), i),
744 instruction, i));
745
746 index->push_back(i);
747 TF_ASSIGN_OR_RETURN(HloInstruction * element,
748 DeepCopyHelper(gte, index, copy_leaf));
749 elements.push_back(element);
750 index->pop_back();
751 }
752 return AddInstruction(HloInstruction::CreateTuple(elements));
753 }
754 if (instruction->shape().IsToken()) {
755 // Tokens have no on-device representation and cannot be copied. Pass
756 // through transparently.
757 return instruction;
758 }
759
760 // Array shape.
761 TF_RET_CHECK(instruction->shape().IsArray());
762 return copy_leaf(instruction, *index, this);
763 }
764
DeepCopyInstruction(HloInstruction * instruction,const ShapeTree<bool> * indices_to_copy,ShapeTree<HloInstruction * > * copies_added)765 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
766 HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
767 ShapeTree<HloInstruction*>* copies_added) {
768 if (instruction->parent() != this) {
769 return FailedPrecondition(
770 "Can't deep copy instruction %s: instruction is not in computation %s",
771 instruction->name(), name());
772 }
773 if (indices_to_copy != nullptr &&
774 !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
775 return FailedPrecondition(
776 "Can't deep copy instruction %s: given shape tree of indices to copy "
777 "has incompatible shapes: %s vs. %s",
778 instruction->name(), ShapeUtil::HumanString(instruction->shape()),
779 ShapeUtil::HumanString(indices_to_copy->shape()));
780 }
781
782 ShapeIndex index;
783 auto copy_leaf = [indices_to_copy, copies_added](
784 HloInstruction* leaf, const ShapeIndex& leaf_index,
785 HloComputation* computation) {
786 if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
787 HloInstruction* copy = computation->AddInstruction(
788 HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
789 if (copies_added != nullptr) {
790 *copies_added->mutable_element(leaf_index) = copy;
791 }
792 return copy;
793 }
794 // Elements which are not to be copied are passed through
795 // transparently.
796 return leaf;
797 };
798 return DeepCopyHelper(instruction, &index, copy_leaf);
799 }
800
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)801 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
802 HloInstruction* instruction,
803 const std::function<
804 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
805 HloComputation* computation)>& copy_leaf) {
806 if (instruction->parent() != this) {
807 return FailedPrecondition(
808 "Can't deep copy instruction %s: instruction is not in computation %s",
809 instruction->name(), name());
810 }
811 ShapeIndex index;
812 return DeepCopyHelper(instruction, &index, copy_leaf);
813 }
814
ComputeProgramShape(bool include_ids) const815 ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
816 ProgramShape program_shape;
817
818 for (auto* param_instruction : param_instructions_) {
819 *program_shape.add_parameters() = param_instruction->shape();
820 *program_shape.add_parameter_names() =
821 PrintName(param_instruction->name(), include_ids);
822 }
823 *program_shape.mutable_result() = root_instruction_->shape();
824
825 return program_shape;
826 }
827
Equal(const HloComputation & other,bool is_layout_sensitive) const828 bool HloComputation::Equal(const HloComputation& other,
829 bool is_layout_sensitive) const {
830 if (this == &other) {
831 return true;
832 }
833 absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
834 visited;
835 std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
836
837 worklist.push_back({root_instruction(), other.root_instruction()});
838
839 while (!worklist.empty()) {
840 auto pair = worklist.back();
841 worklist.pop_back();
842
843 if (visited.contains(pair)) {
844 continue;
845 }
846 visited.emplace(pair);
847 // TODO(b/123082518): Avoid recursively invoking == because it may
848 // cause a stack overflow with deeply nested subcomputations.
849 bool identical_ignoring_operands = pair.first->Identical(
850 *pair.second,
851 [](const HloInstruction*, const HloInstruction*) { return true; },
852 [](const HloComputation* a, const HloComputation* b) {
853 return *a == *b;
854 },
855 is_layout_sensitive);
856 if (!identical_ignoring_operands) {
857 return false;
858 }
859 for (size_t i = 0; i < pair.first->operands().size(); ++i) {
860 worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
861 }
862 }
863 return true;
864 }
865
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)866 Status HloComputation::ReplaceWithNewInstruction(
867 HloInstruction* old_instruction,
868 std::unique_ptr<HloInstruction> new_instruction) {
869 return ReplaceInstruction(old_instruction,
870 AddInstruction(std::move(new_instruction)));
871 }
872
ReplaceWithNewEntryComputationParameter(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)873 Status HloComputation::ReplaceWithNewEntryComputationParameter(
874 HloInstruction* old_instruction,
875 std::unique_ptr<HloInstruction> new_instruction) {
876 return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
877 std::move(new_instruction)));
878 }
879
// Rewires every use of `old_instruction` to `new_instruction` (both must have
// compatible shapes and live in this computation), propagates bookkeeping
// (metadata, frontend attributes, sharding) the caller didn't set explicitly,
// and finally removes `old_instruction` plus any operands left unused.
Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                          HloInstruction* new_instruction) {
  TF_RET_CHECK(
      ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
      << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
      << ShapeUtil::HumanString(new_instruction->shape());

  VLOG(10) << "transformed " << old_instruction->ToString() << " to "
           << new_instruction->ToString();
  // Try to add metadata for HLO instructions that are created to replace
  // existing HLO instructions (e.g. during optimizations). The assumption is
  // that the old instruction and the new instruction would perform the same
  // function, and that they would be correlated to the same TF op. This might
  // not always be correct since HLO optimizations can cross TF op boundaries.
  // But still this seems to be better than nothing.
  if (new_instruction->metadata().op_name().empty()) {
    new_instruction->set_metadata(old_instruction->metadata());
  }
  // Same best-effort propagation for frontend attributes: only copy when the
  // new instruction carries none of its own.
  if (new_instruction->frontend_attributes().map().empty()) {
    new_instruction->set_frontend_attributes(
        old_instruction->frontend_attributes());
  }

  // Like the metadata above, if the user didn't specify any sharding
  // information on the new instruction we should copy the old sharding
  // information (if any).
  if (!new_instruction->has_sharding()) {
    new_instruction->set_sharding(old_instruction->sharding_ptr());
  }

  // Order matters: users are rewired first, so the old instruction becomes
  // dead and can be removed together with its now-unused operands.
  TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
  return RemoveInstructionAndUnusedOperands(old_instruction);
}
913
CollectUnreachableRoots() const914 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
915 std::vector<HloInstruction*> unreachable_roots;
916 for (auto* instruction : instructions()) {
917 if (instruction->user_count() == 0 &&
918 instruction->control_successors().empty() &&
919 instruction != root_instruction()) {
920 unreachable_roots.push_back(instruction);
921 }
922 }
923 VLOG(3) << "Unreachable roots:"
924 << absl::StrJoin(unreachable_roots, "\n\t",
925 [](string* out, const HloInstruction* hlo) {
926 absl::StrAppend(out, hlo->ToString());
927 });
928 return unreachable_roots;
929 }
930
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const931 Status HloComputation::AcceptWithOperandOrder(
932 DfsHloVisitor* visitor,
933 const HloInstruction::CompareFunction& operand_order) const {
934 // Visit unreachable roots. Beware that the visitor might delete the currently
935 // visited root, which would invalidate iterators if the unreachable roots
936 // weren't computed ahead of time.
937 for (HloInstruction* root : CollectUnreachableRoots()) {
938 TF_RETURN_IF_ERROR(
939 root->AcceptWithOperandOrder(visitor, operand_order,
940 /*call_finish_visit=*/false));
941 }
942 // Visit the computation root instruction last.
943 return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
944 /*call_finish_visit=*/true);
945 }
946
Clone(const string & suffix,HloCloneContext * context)947 std::unique_ptr<HloComputation> HloComputation::Clone(
948 const string& suffix, HloCloneContext* context) {
949 return CloneWithReplacements(
950 /*replacements=*/absl::flat_hash_map<const HloInstruction*,
951 std::unique_ptr<HloInstruction>>(),
952 /*extra_parameters=*/{}, context, suffix);
953 }
954
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)955 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
956 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
957 HloCloneContext* context, const string& suffix) {
958 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
959 replacements;
960 replacements.emplace(std::move(r1));
961 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
962 context, suffix);
963 }
964
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)965 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
966 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
967 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
968 HloCloneContext* context, const string& suffix) {
969 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
970 replacements;
971 replacements.emplace(std::move(r1));
972 replacements.emplace(std::move(r2));
973 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
974 context, suffix);
975 }
976
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)977 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
978 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
979 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
980 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
981 HloCloneContext* context, const string& suffix) {
982 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
983 replacements;
984 replacements.emplace(std::move(r1));
985 replacements.emplace(std::move(r2));
986 replacements.emplace(std::move(r3));
987 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
988 context, suffix);
989 }
990
// Clones this computation, applying `replacements`: each key instruction is
// represented in the clone by its mapped value; a null mapped value drops the
// instruction (and everything only reachable through it) from the clone.
// `extra_parameters` (which must all be kParameter) are cloned in first.
// Control dependencies between surviving instructions are re-established.
// `context` records the original->clone mapping; a local one is created when
// the caller passes null.
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const string& suffix) {
  std::unique_ptr<HloCloneContext> context_ptr;
  if (context == nullptr) {
    context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
    context = context_ptr.get();
  }

  // Look up instr in the replacements map, and return either the replacement,
  // or instr, if the replacement isn't present.
  //
  // Note: This can return null, indicating that instr should not be present in
  // the new computation.
  auto replace = [&](HloInstruction* instr) {
    auto it = replacements.find(instr);
    if (it == replacements.end()) {
      return instr;
    }
    return it->second.get();
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We want to do a postorder walk over [replace(i) for i in instructions_].
  // We can't reuse MakeInstructionPostOrder() for this, because that will
  // generate a postorder of plain instructions_, and our replacements may
  // change the postorder!
  //
  // The postorder we want here is simpler than what MakeInstructionPostOrder()
  // does -- we only care about operand dependencies -- so let's just do it
  // ourselves.
  std::vector<HloInstruction*> postorder;
  absl::flat_hash_map<HloInstruction*, VisitState> visited;
  for (const auto& instr : instructions_) {
    std::vector<HloInstruction*> dfs_stack;
    HloInstruction* new_instr = replace(instr.get());
    if (!new_instr) {
      // Instruction was eliminated by the replacements map.
      continue;
    }
    dfs_stack.push_back(new_instr);

    while (!dfs_stack.empty()) {
      auto* cur = dfs_stack.back();
      auto it = visited.find(cur);
      if (it != visited.end()) {
        dfs_stack.pop_back();
        if (it->second == kVisited) {
          continue;
        }
        // Second encounter while kVisiting: all operands have been emitted,
        // so this instruction is next in the postorder.
        CHECK_EQ(it->second, kVisiting);
        postorder.push_back(cur);
        it->second = kVisited;
        continue;
      }

      visited.insert({cur, kVisiting});
      for (HloInstruction* operand : cur->operands()) {
        HloInstruction* new_operand = replace(operand);
        if (new_operand) {
          dfs_stack.emplace_back(new_operand);
        }
      }
    }
  }

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // First add the extra parameters to 'instructions'.
  for (const auto& instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  // Clone each surviving instruction in postorder, remapping operands through
  // both the replacements map and the clone context.
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      // A used operand must not be mapped to null; that would leave a
      // dangling reference in the clone.
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(context->GetInstruction(replaced_operand));
    }
    std::unique_ptr<HloInstruction> new_instr =
        instr->CloneWithNewOperands(instr->shape(), new_operands, context);
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parameter_replicated_at_leaf_buffers().has_value()) {
      // Preserve per-leaf-buffer replication annotations on parameters.
      new_instr->set_parameter_replicated_at_leaf_buffers(
          instr->parameter_replicated_at_leaf_buffers().value());
    }
    instructions.push_back(std::move(new_instr));
  }
  Builder builder(name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(
          replace(root_instruction())));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = context->GetInstruction(instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);
      // successor may not have been remapped, because it might have been
      // removed by the replacements map.
      if (replaced_successor != nullptr) {
        TF_CHECK_OK(new_instr->AddControlDependencyTo(
            context->GetInstruction(replaced_successor)));
      }
    }
  }
  context->MapComputation(this, result.get());
  return result;
}
1108
UniquifyName(NameUniquer * name_uniquer)1109 void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
1110 name_ = name_uniquer->GetUniqueName(name_);
1111 }
1112
GetInstructionWithName(absl::string_view name)1113 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1114 auto instructions_in_computation = instructions();
1115 auto it = absl::c_find_if(
1116 instructions_in_computation,
1117 [&](HloInstruction* instr) { return instr->name() == name; });
1118 return it == instructions_in_computation.end() ? nullptr : *it;
1119 }
1120
IsEntryComputation() const1121 bool HloComputation::IsEntryComputation() const {
1122 return parent()->entry_computation() == this;
1123 }
1124 } // namespace xla
1125