1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46
47 namespace xla {
48
49 using absl::StrCat;
50
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52 HloInstruction* root_instruction) {
53 int parameter_count = 0;
54 for (auto& instruction : instructions_) {
55 if (instruction->opcode() == HloOpcode::kParameter) {
56 parameter_count++;
57 }
58 }
59 // If root_instruction is not specified use the last added instruction.
60 HloInstruction* root =
61 root_instruction ? root_instruction : last_added_instruction_;
62 CHECK_NE(nullptr, root);
63 return absl::WrapUnique(new HloComputation(
64 name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
66
HloComputation(const string & name,int parameter_count,std::vector<std::unique_ptr<HloInstruction>> * instructions,HloInstruction * root_instruction,HloInstruction * fusion_instruction)67 HloComputation::HloComputation(
68 const string& name, int parameter_count,
69 std::vector<std::unique_ptr<HloInstruction>>* instructions,
70 HloInstruction* root_instruction, HloInstruction* fusion_instruction)
71 : name_(NameUniquer::GetSanitizedName(name)),
72 unique_id_(-1),
73 root_instruction_(root_instruction),
74 fusion_instruction_(fusion_instruction) {
75 param_instructions_.resize(parameter_count, nullptr);
76 bool root_found = false;
77 for (auto& instruction : *instructions) {
78 if (instruction->opcode() == HloOpcode::kParameter) {
79 int64 param_no = instruction->parameter_number();
80 CHECK(param_no >= 0 && param_no < parameter_count)
81 << "\nERROR: invalid parameter number. Expected [0, "
82 << parameter_count << "), got " << param_no;
83 CHECK(param_instructions_[param_no] == nullptr)
84 << "\nERROR: parameter number " << param_no
85 << " already allocated in this computation";
86 param_instructions_[param_no] = instruction.get();
87 }
88 root_found |= instruction.get() == root_instruction_;
89 AddInstructionInternal(std::move(instruction));
90 }
91 CHECK(root_found)
92 << "\nERROR: root instruction is not present in computation.";
93 }
94
AddInstruction(std::unique_ptr<HloInstruction> instruction)95 HloInstruction* HloComputation::AddInstruction(
96 std::unique_ptr<HloInstruction> instruction) {
97 CHECK(instruction->opcode() != HloOpcode::kParameter)
98 << "Parameter instructions cannot be added to a computation after "
99 << "it has been built";
100 return AddInstructionInternal(std::move(instruction));
101 }
102
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)103 HloInstruction* HloComputation::AddInstructionInternal(
104 std::unique_ptr<HloInstruction> instruction) {
105 if (parent() != nullptr) {
106 instruction->UniquifyName(&parent()->instruction_name_uniquer());
107 instruction->SetUniqueId(parent()->NewUniqueInstructionId());
108 }
109 instruction->set_parent(this);
110 HloInstruction* pinst = instruction.get();
111 instruction_iterators_[pinst] =
112 instructions_.insert(instructions_.end(), std::move(instruction));
113 return pinst;
114 }
115
AddParameter(std::unique_ptr<HloInstruction> instruction)116 HloInstruction* HloComputation::AddParameter(
117 std::unique_ptr<HloInstruction> instruction) {
118 CHECK(instruction->opcode() == HloOpcode::kParameter);
119 CHECK(IsFusionComputation());
120 CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
121 instruction->set_parent(this);
122 param_instructions_.push_back(instruction.get());
123 AddInstructionInternal(std::move(instruction));
124 return instructions_.back().get();
125 }
126
AddEntryComputationParameter(std::unique_ptr<HloInstruction> instruction)127 HloInstruction* HloComputation::AddEntryComputationParameter(
128 std::unique_ptr<HloInstruction> instruction) {
129 CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
130 CHECK_EQ(instruction->parameter_number(), num_parameters());
131 CHECK(parent()->entry_computation() == this);
132
133 HloModuleConfig config = parent()->config();
134 config.mutable_entry_computation_layout()->add_parameter_layout(
135 ShapeLayout(instruction->shape()));
136 parent()->set_config(config);
137
138 instruction->set_parent(this);
139 param_instructions_.push_back(instruction.get());
140 AddInstructionInternal(std::move(instruction));
141
142 return instructions_.back().get();
143 }
144
RemoveParameter(int64 param_no)145 Status HloComputation::RemoveParameter(int64 param_no) {
146 CHECK_GE(param_no, 0);
147 CHECK_LT(param_no, param_instructions_.size());
148 CHECK(IsFusionComputation());
149 HloInstruction* param_instruction = param_instructions_[param_no];
150 auto param_instruction_iterator = param_instructions_.begin() + param_no;
151 param_instructions_.erase(param_instruction_iterator);
152 // Throw removed fused parameter instruction away.
153 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
154
155 while (param_no < param_instructions_.size()) {
156 param_instruction = param_instructions_[param_no];
157 HloInstruction* new_instr =
158 AddInstructionInternal(HloInstruction::CreateParameter(
159 param_no, param_instruction->shape(), StrCat("param_", param_no)));
160 TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
161 param_instructions_[param_no] = new_instr;
162 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
163 param_no++;
164 }
165
166 return Status::OK();
167 }
168
RemoveUnusedParameters()169 Status HloComputation::RemoveUnusedParameters() {
170 CHECK(IsFusionComputation());
171 int64 removed = 0;
172 for (int64 i = 0; i < param_instructions_.size(); ++i) {
173 HloInstruction* param_instruction = param_instructions_[i];
174 if (param_instruction->user_count() == 0 &&
175 param_instruction != root_instruction()) {
176 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
177 ++removed;
178 continue;
179 }
180
181 if (removed > 0) {
182 const int64 param_no = i - removed;
183 HloInstruction* new_instr = AddInstructionInternal(
184 HloInstruction::CreateParameter(param_no, param_instruction->shape(),
185 StrCat("param_", param_no)));
186 TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
187 param_instructions_[param_no] = new_instr;
188 TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
189 }
190 }
191 param_instructions_.resize(param_instructions_.size() - removed);
192 return Status::OK();
193 }
194
IsRemovable(const HloInstruction * instruction)195 bool HloComputation::IsRemovable(const HloInstruction* instruction) {
196 // If the instruction has control predecessors or successors then we cannot
197 // remove the instruction without violating ordering constraints (added, for
198 // example, to avert interference due to buffer aliasing).
199 if (!instruction->control_predecessors().empty() ||
200 !instruction->control_successors().empty()) {
201 return false;
202 }
203
204 if (instruction->opcode() == HloOpcode::kParameter &&
205 !IsFusionComputation()) {
206 return false;
207 }
208
209 return true;
210 }
211
HasSideEffect() const212 bool HloComputation::HasSideEffect() const {
213 for (auto* instruction : instructions()) {
214 if (instruction->HasSideEffect()) {
215 return true;
216 }
217 }
218 return false;
219 }
220
RemoveInstructionAndUnusedOperands(HloInstruction * instruction)221 Status HloComputation::RemoveInstructionAndUnusedOperands(
222 HloInstruction* instruction) {
223 TF_RET_CHECK(root_instruction() != instruction);
224
225 TF_RET_CHECK(instruction->user_count() == 0);
226 TF_RET_CHECK(IsRemovable(instruction))
227 << "Cannot remove instruction: " << instruction->ToString();
228 absl::flat_hash_set<HloInstruction*> removed;
229 std::queue<HloInstruction*> worklist;
230 worklist.push(instruction);
231 while (!worklist.empty()) {
232 HloInstruction* item = worklist.front();
233 worklist.pop();
234
235 if (removed.contains(item) || item->user_count() != 0 ||
236 item == root_instruction() || !IsRemovable(item) ||
237 (item->HasSideEffect() && item != instruction)) {
238 continue;
239 }
240 for (int i = 0; i < item->operand_count(); ++i) {
241 worklist.push(item->mutable_operand(i));
242 }
243
244 TF_RETURN_IF_ERROR(RemoveInstruction(item));
245 removed.insert(item);
246 }
247 return Status::OK();
248 }
249
RemoveInstruction(HloInstruction * instruction)250 Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
251 VLOG(2) << "Removing instruction " << instruction->name()
252 << " from computation " << name();
253 TF_RET_CHECK(IsRemovable(instruction))
254 << "cannot remove instruction: " << instruction->ToString();
255 TF_RET_CHECK(root_instruction() != instruction)
256 << "cannot remove root instruction " << instruction->name();
257 TF_RET_CHECK(instruction->user_count() == 0)
258 << "instruction " << instruction->name()
259 << " has users and cannot be removed";
260 TF_RET_CHECK(instruction->control_predecessors().empty())
261 << "instruction " << instruction->name()
262 << " has control predecessors and cannot be removed";
263 TF_RET_CHECK(instruction->control_successors().empty())
264 << "instruction " << instruction->name()
265 << " has control successors and cannot be removed";
266
267 auto inst_it = instruction_iterators_.find(instruction);
268 TF_RET_CHECK(inst_it != instruction_iterators_.end());
269 (*inst_it->second)->set_parent(nullptr);
270 instructions_.erase(inst_it->second);
271 instruction_iterators_.erase(inst_it);
272 return Status::OK();
273 }
274
set_root_instruction(HloInstruction * new_root_instruction,bool accept_different_shape)275 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
276 bool accept_different_shape) {
277 // The shape of the root (ignoring layout) is an invariant of the computation
278 // for non-fusion cases.
279 if (!IsFusionComputation() && !accept_different_shape) {
280 CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
281 root_instruction_->shape()))
282 << new_root_instruction->shape() << " is incompatible with "
283 << root_instruction_->shape();
284 }
285 bool root_found = false;
286 for (auto& instruction : instructions_) {
287 if (new_root_instruction == instruction.get()) {
288 root_found = true;
289 break;
290 }
291 }
292 DCHECK(root_found);
293
294 root_instruction_ = new_root_instruction;
295 }
296
297 namespace {
298
299 // Helper which builds a post order of the HLO call graph.
ComputeComputationPostOrder(HloComputation * computation,absl::flat_hash_set<HloComputation * > * visited,std::vector<HloComputation * > * post_order)300 void ComputeComputationPostOrder(HloComputation* computation,
301 absl::flat_hash_set<HloComputation*>* visited,
302 std::vector<HloComputation*>* post_order) {
303 if (visited->insert(computation).second) {
304 for (auto* instruction : computation->instructions()) {
305 for (HloComputation* called_computation :
306 instruction->called_computations()) {
307 ComputeComputationPostOrder(called_computation, visited, post_order);
308 }
309 }
310 post_order->push_back(computation);
311 }
312 }
313
314 } // namespace
315
ComputeInstructionPostOrder(const HloComputation::ChannelDependencyGroup & channel_dependency_group,std::vector<HloInstruction * > * post_order,HloInstruction * root,absl::flat_hash_map<HloInstruction *,VisitState> * visited) const316 void HloComputation::ComputeInstructionPostOrder(
317 const HloComputation::ChannelDependencyGroup& channel_dependency_group,
318 std::vector<HloInstruction*>* post_order, HloInstruction* root,
319 absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
320 std::vector<HloInstruction*> dfs_stack;
321 dfs_stack.push_back(root);
322 while (!dfs_stack.empty()) {
323 const auto current = dfs_stack.back();
324 auto it = visited->find(current);
325 if (it != visited->end()) {
326 if (it->second == kVisited) {
327 // Already visited.
328 dfs_stack.pop_back();
329 continue;
330 }
331 // Visit this node.
332 CHECK_EQ(kVisiting, it->second);
333 dfs_stack.pop_back();
334 post_order->push_back(current);
335 it->second = kVisited;
336 continue;
337 }
338
339 visited->insert({current, kVisiting});
340
341 const auto get_channel_id =
342 [](HloInstruction* inst) -> absl::optional<int64> {
343 switch (inst->opcode()) {
344 case HloOpcode::kRecvDone:
345 return inst->channel_id();
346 case HloOpcode::kAllReduce:
347 return inst->all_reduce_id();
348 default:
349 return absl::nullopt;
350 }
351 };
352
353 // When adding a predecessor to the dfs_stack, we need to also add its
354 // associated channel dependencies.
355 const auto add_dfs_stack = [&](HloInstruction* inst) {
356 auto channel_id = get_channel_id(inst);
357 if (channel_id && channel_dependency_group.count(*channel_id)) {
358 auto it = channel_dependency_group.find(*channel_id);
359 for (HloInstruction* cinst : it->second) {
360 dfs_stack.emplace_back(cinst);
361 }
362 } else {
363 dfs_stack.emplace_back(inst);
364 }
365 };
366
367 const auto add_predecessors = [&](HloInstruction* inst) {
368 // Add the operands to the stack in reverse order so the first operand is
369 // processed first. This will produce a more natural ordering and a nicer
370 // result for things like HLO stringification.
371 const auto& operands = inst->operands();
372 for (int64 i = operands.size() - 1; i >= 0; --i) {
373 add_dfs_stack(operands[i]);
374 }
375
376 for (HloInstruction* op : inst->control_predecessors()) {
377 add_dfs_stack(op);
378 }
379 };
380
381 // If the current instruction is a channel instruction, add the dependencies
382 // from all associated instructions of the channel.
383 auto channel_id = get_channel_id(current);
384 if (channel_id && channel_dependency_group.count(*channel_id)) {
385 auto it = channel_dependency_group.find(*channel_id);
386 for (HloInstruction* cinst : it->second) {
387 add_predecessors(cinst);
388 }
389 } else {
390 add_predecessors(current);
391 }
392 }
393 }
394
395 HloComputation::ChannelDependencyGroup
ComputeChannelDependencies() const396 HloComputation::ComputeChannelDependencies() const {
397 ChannelDependencyGroup channel_dependency_group;
398 for (const auto& instruction : instructions_) {
399 switch (instruction->opcode()) {
400 case HloOpcode::kSend:
401 case HloOpcode::kRecvDone:
402 channel_dependency_group[instruction->channel_id()].push_back(
403 instruction.get());
404 break;
405 case HloOpcode::kAllReduce: {
406 auto all_reduce_id = instruction->all_reduce_id();
407 if (all_reduce_id) {
408 channel_dependency_group[all_reduce_id.value()].push_back(
409 instruction.get());
410 }
411 break;
412 }
413 default:
414 break;
415 }
416 }
417 return channel_dependency_group;
418 }
419
MakeInstructionPostOrder() const420 std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
421 auto channel_dependency_group = ComputeChannelDependencies();
422 std::vector<HloInstruction*> post_order;
423 post_order.reserve(instruction_count());
424 std::vector<HloInstruction*> trace_instructions;
425 absl::flat_hash_map<HloInstruction*, VisitState> visited;
426 visited.reserve(instruction_count());
427 for (auto& instruction : instructions_) {
428 if (instruction->opcode() == HloOpcode::kTrace) {
429 // Trace instructions aren't handled by the DFS visitor. Add trace
430 // instructions to the post order at the end (necessarily they have no
431 // users).
432 trace_instructions.push_back(instruction.get());
433 } else if (instruction->users().empty()) {
434 ComputeInstructionPostOrder(channel_dependency_group, &post_order,
435 instruction.get(), &visited);
436 }
437 }
438 post_order.insert(post_order.end(), trace_instructions.begin(),
439 trace_instructions.end());
440 CHECK_EQ(instructions_.size(), post_order.size())
441 << "number of instructions does not match post order size";
442 return post_order;
443 }
444
MakeEmbeddedComputationsList() const445 std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
446 const {
447 absl::flat_hash_set<HloComputation*> visited;
448 std::vector<HloComputation*> post_order;
449
450 // To avoid special handling of this computation, cast away const of
451 // 'this'. 'this' is immediately removed from the post order after
452 // construction.
453 //
454 // TODO(b/78350259): This violates const-correctness, since while the original
455 // computation is not returned, we still retrieve non-const computations from
456 // a const one. Consider also avoiding const for HloComputation, or review XLA
457 // for const-correctness of non-HloInstruction* types like this.
458 ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
459 &post_order);
460
461 // We don't want to include this computation in the post order.
462 CHECK_EQ(this, post_order.back());
463 post_order.pop_back();
464
465 return post_order;
466 }
467
ToString(const HloPrintOptions & options) const468 string HloComputation::ToString(const HloPrintOptions& options) const {
469 return ToString(options, MakeInstructionPostOrder());
470 }
471
ToString(const HloPrintOptions & options,absl::Span<const HloInstruction * const> instruction_order) const472 string HloComputation::ToString(
473 const HloPrintOptions& options,
474 absl::Span<const HloInstruction* const> instruction_order) const {
475 CHECK_EQ(instruction_order.size(), instruction_count());
476
477 std::ostringstream s;
478 for (int i = 0; i < options.indent_amount(); i++) {
479 s << " ";
480 }
481
482 if (!options.is_in_nested_computation()) {
483 if (options.print_percent()) {
484 s << "%";
485 }
486 s << name() << " ";
487 }
488
489 if (options.print_program_shape()) {
490 s << ShapeUtil::HumanString(ComputeProgramShape()) << " ";
491 }
492 s << "{\n";
493 {
494 // Print the instructions in this computation.
495 HloPrintOptions new_options = options;
496 new_options.set_indent_amount(options.indent_amount() + 1)
497 .set_is_in_nested_computation(true);
498 CanonicalNameMap name_map;
499 for (const HloInstruction* instruction : instruction_order) {
500 CHECK_EQ(this, instruction->parent());
501
502 for (int i = 0; i < new_options.indent_amount(); i++) {
503 s << " ";
504 }
505 s << (instruction == root_instruction_ ? "ROOT " : "")
506 << instruction->ToStringWithCanonicalNameMap(new_options, &name_map)
507 << "\n";
508 }
509 }
510
511 for (int i = 0; i < options.indent_amount(); i++) {
512 s << " ";
513 }
514 s << "}";
515 return s.str();
516 }
517
ToProto() const518 HloComputationProto HloComputation::ToProto() const {
519 HloComputationProto proto;
520 CHECK(unique_id_ != -1)
521 << "This computation does not have a valid id. Please make sure the "
522 "computation is inside a module before dumping it.";
523 proto.set_id(unique_id_);
524 proto.set_name(name_);
525 for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
526 HloInstructionProto instruction_proto = instruction->ToProto();
527 proto.add_instructions()->Swap(&instruction_proto);
528 }
529 proto.set_root_id(root_instruction()->unique_id());
530 *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
531 return proto;
532 }
533
534 /* static */ StatusOr<std::unique_ptr<HloComputation>>
CreateFromProto(const HloComputationProto & proto,const absl::flat_hash_map<int64,HloComputation * > & computation_map)535 HloComputation::CreateFromProto(
536 const HloComputationProto& proto,
537 const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
538 absl::flat_hash_map<int64, HloInstruction*> instruction_map;
539 absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
540 std::vector<std::unique_ptr<HloInstruction>> instructions;
541 int64 parameter_count = 0;
542 for (const HloInstructionProto& instruction_proto : proto.instructions()) {
543 TF_ASSIGN_OR_RETURN(
544 std::unique_ptr<HloInstruction> instruction,
545 HloInstruction::CreateFromProto(instruction_proto, instruction_map,
546 computation_map));
547 if (instruction->opcode() == HloOpcode::kParameter) {
548 parameter_count++;
549 }
550 TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
551 instruction_map[instruction_proto.id()] = instruction.get();
552 to_proto_id[instruction.get()] = instruction_proto.id();
553 instructions.push_back(std::move(instruction));
554 }
555
556 TF_RET_CHECK(proto.root_id() != -1);
557 TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
558 HloInstruction* root = instruction_map.at(proto.root_id());
559
560 // Sort the instructions in the proto id's order.
561 absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
562 const std::unique_ptr<HloInstruction>& b) {
563 return to_proto_id[a.get()] < to_proto_id[b.get()];
564 });
565
566 TF_RETURN_IF_ERROR([&]() -> Status {
567 std::vector<bool> parameters_seen(parameter_count);
568 int parameters_seen_count = 0;
569 for (auto& instruction : instructions) {
570 if (instruction->opcode() == HloOpcode::kParameter) {
571 int64 param_no = instruction->parameter_number();
572 TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
573 << "Invalid parameter number. Expected [0, " << parameter_count
574 << "), got " << param_no;
575 TF_RET_CHECK(!parameters_seen[param_no])
576 << "Parameter number " << param_no
577 << " already allocated in this computation";
578 parameters_seen[param_no] = true;
579 parameters_seen_count++;
580 }
581 }
582 TF_RET_CHECK(parameters_seen_count == parameter_count)
583 << "Not all parameters in range [0, " << parameter_count
584 << ") were referenced";
585 return Status::OK();
586 }());
587
588 auto computation = absl::WrapUnique(
589 new HloComputation(proto.name(), parameter_count, &instructions, root,
590 /*fusion_instruction=*/nullptr));
591 computation->unique_id_ = proto.id();
592 return std::move(computation);
593 }
594
FuseInstructionsInto(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction * fusion_instruction)595 void HloComputation::FuseInstructionsInto(
596 absl::Span<HloInstruction* const> instructions_to_fuse,
597 HloInstruction* fusion_instruction) {
598 CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
599 HloInstruction* root = instructions_to_fuse.front();
600 TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
601 if (root == root_instruction()) {
602 set_root_instruction(fusion_instruction);
603 }
604 TF_CHECK_OK(RemoveInstruction(root));
605 for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
606 HloInstruction* instruction = instructions_to_fuse[i];
607 fusion_instruction->FuseInstruction(instruction);
608 if (instruction->user_count() == 0) {
609 TF_CHECK_OK(RemoveInstruction(instruction));
610 }
611 }
612 }
613
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)614 HloInstruction* HloComputation::CreateFusionInstruction(
615 absl::Span<HloInstruction* const> instructions_to_fuse,
616 HloInstruction::FusionKind fusion_kind) {
617 HloInstruction* root = instructions_to_fuse.front();
618 HloInstruction* fusion_instruction = AddInstruction(
619 HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
620 FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
621 return fusion_instruction;
622 }
623
DeepCopyHelper(HloInstruction * instruction,ShapeIndex * index,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)624 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
625 HloInstruction* instruction, ShapeIndex* index,
626 const std::function<
627 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
628 HloComputation* computation)>& copy_leaf) {
629 if (instruction->shape().IsTuple()) {
630 std::vector<HloInstruction*> elements;
631 for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
632 i++) {
633 HloInstruction* gte =
634 AddInstruction(HloInstruction::CreateGetTupleElement(
635 ShapeUtil::GetTupleElementShape(instruction->shape(), i),
636 instruction, i));
637
638 index->push_back(i);
639 TF_ASSIGN_OR_RETURN(HloInstruction * element,
640 DeepCopyHelper(gte, index, copy_leaf));
641 elements.push_back(element);
642 index->pop_back();
643 }
644 return AddInstruction(HloInstruction::CreateTuple(elements));
645 }
646 if (instruction->shape().IsToken()) {
647 // Tokens have no on-device representation and cannot be copied. Pass
648 // through transparently.
649 return instruction;
650 }
651
652 // Array shape.
653 TF_RET_CHECK(instruction->shape().IsArray());
654 return copy_leaf(instruction, *index, this);
655 }
656
DeepCopyInstruction(HloInstruction * instruction,const ShapeTree<bool> * indices_to_copy,ShapeTree<HloInstruction * > * copies_added)657 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
658 HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
659 ShapeTree<HloInstruction*>* copies_added) {
660 if (instruction->parent() != this) {
661 return FailedPrecondition(
662 "Can't deep copy instruction %s: instruction is not in computation %s",
663 instruction->name(), name());
664 }
665 if (indices_to_copy != nullptr &&
666 !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
667 return FailedPrecondition(
668 "Can't deep copy instruction %s: given shape tree of indices to copy "
669 "has incompatible shapes: %s vs. %s",
670 instruction->name(), ShapeUtil::HumanString(instruction->shape()),
671 ShapeUtil::HumanString(indices_to_copy->shape()));
672 }
673
674 ShapeIndex index;
675 auto copy_leaf = [indices_to_copy, copies_added](
676 HloInstruction* leaf, const ShapeIndex& leaf_index,
677 HloComputation* computation) {
678 if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
679 HloInstruction* copy = computation->AddInstruction(
680 HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
681 if (copies_added != nullptr) {
682 *copies_added->mutable_element(leaf_index) = copy;
683 }
684 return copy;
685 }
686 // Elements which are not to be copied are passed through
687 // transparently.
688 return leaf;
689 };
690 return DeepCopyHelper(instruction, &index, copy_leaf);
691 }
692
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)693 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
694 HloInstruction* instruction,
695 const std::function<
696 HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
697 HloComputation* computation)>& copy_leaf) {
698 if (instruction->parent() != this) {
699 return FailedPrecondition(
700 "Can't deep copy instruction %s: instruction is not in computation %s",
701 instruction->name(), name());
702 }
703 ShapeIndex index;
704 return DeepCopyHelper(instruction, &index, copy_leaf);
705 }
706
ComputeProgramShape() const707 ProgramShape HloComputation::ComputeProgramShape() const {
708 ProgramShape program_shape;
709
710 for (auto* param_instruction : param_instructions_) {
711 *program_shape.add_parameters() = param_instruction->shape();
712 *program_shape.add_parameter_names() = param_instruction->name();
713 }
714 *program_shape.mutable_result() = root_instruction_->shape();
715
716 return program_shape;
717 }
718
operator ==(const HloComputation & other) const719 bool HloComputation::operator==(const HloComputation& other) const {
720 if (this == &other) {
721 return true;
722 }
723 absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
724 visited;
725 std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
726
727 worklist.push_back({root_instruction(), other.root_instruction()});
728
729 while (!worklist.empty()) {
730 auto pair = worklist.back();
731 worklist.pop_back();
732
733 if (visited.contains(pair)) {
734 continue;
735 }
736 visited.emplace(pair);
737 // TODO(b/123082518): Avoid recursively invoking == becasue it may
738 // cause a stack overflow with deeply nested subcomputations.
739 bool identical_ignoring_operands = pair.first->Identical(
740 *pair.second,
741 [](const HloInstruction*, const HloInstruction*) { return true; },
742 [](const HloComputation* a, const HloComputation* b) {
743 return *a == *b;
744 });
745 if (!identical_ignoring_operands) {
746 return false;
747 }
748 for (size_t i = 0; i < pair.first->operands().size(); ++i) {
749 worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
750 }
751 }
752 return true;
753 }
754
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)755 Status HloComputation::ReplaceWithNewInstruction(
756 HloInstruction* old_instruction,
757 std::unique_ptr<HloInstruction> new_instruction) {
758 return ReplaceInstruction(old_instruction,
759 AddInstruction(std::move(new_instruction)));
760 }
761
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)762 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
763 HloInstruction* new_instruction) {
764 TF_RET_CHECK(
765 ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
766 << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
767 << ShapeUtil::HumanString(new_instruction->shape());
768
769 VLOG(10) << "transformed " << old_instruction->ToString() << " to "
770 << new_instruction->ToString();
771 // Try to add metadata for HLO instructions that are created to replace
772 // existing HLO instructions (e.g. during optimizations). The assumption is
773 // that the old instruction and the new instruction would perform the same
774 // function, and that they would be correlated to the same TF op. This might
775 // not always be correct since HLO optimizations can cross TF op boundaries.
776 // But still this seems to be better than nothing.
777 if (new_instruction->metadata().op_name().empty()) {
778 new_instruction->set_metadata(old_instruction->metadata());
779 }
780 TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
781 return RemoveInstructionAndUnusedOperands(old_instruction);
782 }
783
CollectUnreachableRoots() const784 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
785 std::vector<HloInstruction*> unreachable_roots;
786 for (auto* instruction : instructions()) {
787 if (instruction->user_count() == 0 &&
788 instruction->control_successors().empty() &&
789 instruction != root_instruction()) {
790 unreachable_roots.push_back(instruction);
791 }
792 }
793 VLOG(3) << "Unreachable roots:"
794 << absl::StrJoin(unreachable_roots, "\n\t",
795 [](string* out, const HloInstruction* hlo) {
796 absl::StrAppend(out, hlo->ToString());
797 });
798 return unreachable_roots;
799 }
800
801 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor) const802 Status HloComputation::Accept(
803 DfsHloVisitorBase<HloInstructionPtr>* visitor) const {
804 // Visit unreachable roots. Beware that the visitor might delete the currently
805 // visited root, which would invalidate iterators if the unreachable roots
806 // weren't computed ahead of time.
807 for (HloInstruction* root : CollectUnreachableRoots()) {
808 VLOG(3) << "Traversing unreachable root: " << root->ToString();
809 // Call FinishVisit only at the end.
810 TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false));
811 }
812 // Visit the computation root instruction last.
813 return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
814 }
815
// Explicit instantiations for the two visitor flavors used by callers.
template Status HloComputation::Accept(DfsHloVisitor* visitor) const;
template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const;
819
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const820 Status HloComputation::AcceptWithOperandOrder(
821 DfsHloVisitor* visitor,
822 const HloInstruction::CompareFunction& operand_order) const {
823 // Visit unreachable roots. Beware that the visitor might delete the currently
824 // visited root, which would invalidate iterators if the unreachable roots
825 // weren't computed ahead of time.
826 for (HloInstruction* root : CollectUnreachableRoots()) {
827 TF_RETURN_IF_ERROR(
828 root->AcceptWithOperandOrder(visitor, operand_order,
829 /*call_finish_visit=*/false));
830 }
831 // Visit the computation root instruction last.
832 return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
833 /*call_finish_visit=*/true);
834 }
835
836 template <typename HloInstructionPtr>
AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr> * visitor,absl::Span<HloInstruction * const> order) const837 Status HloComputation::AcceptOrdered(
838 DfsHloVisitorBase<HloInstructionPtr>* visitor,
839 absl::Span<HloInstruction* const> order) const {
840 VLOG(3) << "Accepting visitor with order.";
841 for (HloInstruction* root : CollectUnreachableRoots()) {
842 TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
843 }
844 TF_RET_CHECK(order.size() == instruction_count());
845 absl::flat_hash_set<const HloInstruction*> visited;
846 for (const HloInstruction* instruction : order) {
847 VLOG(3) << "Visiting ordered: " << instruction->ToString();
848 TF_RET_CHECK(instruction_iterators_.contains(instruction))
849 << "Instruction " << instruction->name() << " is not in computation "
850 << name();
851 TF_RET_CHECK(!visited.contains(instruction))
852 << "Instruction " << instruction->name()
853 << " appears more than once in order";
854 HloInstruction* mutable_instruction =
855 const_cast<HloInstruction*>(instruction);
856 TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction));
857 TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor));
858 visitor->SetVisited(*mutable_instruction);
859 TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction));
860 visited.insert(instruction);
861 }
862 TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction()));
863 return Status::OK();
864 }
865
// Explicit instantiations for the two visitor flavors used by callers.
template Status HloComputation::AcceptOrdered(
    DfsHloVisitor*, absl::Span<HloInstruction* const>) const;
template Status HloComputation::AcceptOrdered(
    ConstDfsHloVisitor*, absl::Span<HloInstruction* const>) const;
871
Accept(const std::function<Status (HloInstruction *)> & visitor_func)872 Status HloComputation::Accept(
873 const std::function<Status(HloInstruction*)>& visitor_func) {
874 FunctionVisitor visitor(visitor_func);
875 return this->Accept(&visitor);
876 }
877
Accept(const std::function<Status (const HloInstruction *)> & visitor_func) const878 Status HloComputation::Accept(
879 const std::function<Status(const HloInstruction*)>& visitor_func) const {
880 ConstFunctionVisitor visitor(visitor_func);
881 return this->Accept(&visitor);
882 }
883
Clone(const string & suffix,HloCloneContext * context)884 std::unique_ptr<HloComputation> HloComputation::Clone(
885 const string& suffix, HloCloneContext* context) {
886 return CloneWithReplacements(
887 /*replacements=*/absl::flat_hash_map<const HloInstruction*,
888 std::unique_ptr<HloInstruction>>(),
889 /*extra_parameters=*/{}, context, suffix);
890 }
891
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)892 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
893 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
894 HloCloneContext* context, const string& suffix) {
895 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
896 replacements;
897 replacements.emplace(std::move(r1));
898 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
899 context, suffix);
900 }
901
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)902 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
903 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
904 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
905 HloCloneContext* context, const string& suffix) {
906 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
907 replacements;
908 replacements.emplace(std::move(r1));
909 replacements.emplace(std::move(r2));
910 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
911 context, suffix);
912 }
913
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)914 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
915 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
916 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
917 std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
918 HloCloneContext* context, const string& suffix) {
919 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
920 replacements;
921 replacements.emplace(std::move(r1));
922 replacements.emplace(std::move(r2));
923 replacements.emplace(std::move(r3));
924 return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
925 context, suffix);
926 }
927
// Clones this computation, mapping each instruction in `replacements` to its
// replacement (a null mapped value removes the instruction entirely), and
// prepending the parameter instructions in `extra_parameters` to the clone.
// Records the original->clone mapping in `context`; creates a local context
// when the caller passes null.
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const string& suffix) {
  std::unique_ptr<HloCloneContext> context_ptr;
  if (context == nullptr) {
    context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
    context = context_ptr.get();
  }

  // Look up instr in the replacements map, and return either the replacement,
  // or instr, if the replacement isn't present.
  //
  // Note: This can return null, indicating that instr should not be present in
  // the new computation.
  auto replace = [&](HloInstruction* instr) {
    auto it = replacements.find(instr);
    if (it == replacements.end()) {
      return instr;
    }
    return it->second.get();
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We want to do a postorder walk over [replace(i) for i in instructions_].
  // We can't reuse MakeInstructionPostOrder() for this, because that will
  // generate a postorder of plain instructions_, and our replacements may
  // change the postorder!
  //
  // The postorder we want here is simpler than what MakeInstructionPostOrder()
  // does -- we only care about operand dependencies -- so let's just do it
  // ourselves.
  std::vector<HloInstruction*> postorder;
  absl::flat_hash_map<HloInstruction*, VisitState> visited;
  for (const auto& instr : instructions_) {
    std::vector<HloInstruction*> dfs_stack;
    HloInstruction* new_instr = replace(instr.get());
    if (!new_instr) {
      // Instruction is mapped to null: drop it from the clone.
      continue;
    }
    dfs_stack.push_back(new_instr);

    // Iterative DFS. kVisiting marks a node whose operands are still being
    // expanded; when it is popped a second time, all operands are done and it
    // is appended to the postorder.
    while (!dfs_stack.empty()) {
      auto* cur = dfs_stack.back();
      auto it = visited.find(cur);
      if (it != visited.end()) {
        dfs_stack.pop_back();
        if (it->second == kVisited) {
          continue;
        }
        CHECK_EQ(it->second, kVisiting);
        postorder.push_back(cur);
        it->second = kVisited;
        continue;
      }

      visited.insert({cur, kVisiting});
      for (HloInstruction* operand : cur->operands()) {
        HloInstruction* new_operand = replace(operand);
        if (new_operand) {
          dfs_stack.emplace_back(new_operand);
        }
      }
    }
  }

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // First add the extra parameters to 'instructions'.
  for (const auto& instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  // Clone in postorder so that every operand is already cloned (and therefore
  // resolvable through `context`) when its user is cloned.
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(context->GetInstruction(replaced_operand));
    }
    instructions.push_back(
        instr->CloneWithNewOperands(instr->shape(), new_operands, context));
  }
  Builder builder(name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(
          replace(root_instruction())));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = context->GetInstruction(instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);
      // successor may not have been remapped, because it might have been
      // removed by the replacements map.
      if (replaced_successor != nullptr) {
        TF_CHECK_OK(new_instr->AddControlDependencyTo(
            context->GetInstruction(replaced_successor)));
      }
    }
  }
  context->MapComputation(this, result.get());
  return result;
}
1039
// Rewrites this computation's name to a unique one obtained from
// `name_uniquer`, using the current name as the requested base.
void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  name_ = name_uniquer->GetUniqueName(name_);
}
1043
GetInstructionWithName(absl::string_view name)1044 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1045 auto instructions_in_computation = instructions();
1046 auto it = absl::c_find_if(
1047 instructions_in_computation,
1048 [&](HloInstruction* instr) { return instr->name() == name; });
1049 return it == instructions_in_computation.end() ? nullptr : *it;
1050 }
1051
1052 } // namespace xla
1053