/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_computation.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <optional>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using absl::StrCat;

std::unique_ptr<HloComputation> HloComputation::Builder::Build(
    HloInstruction* root_instruction) {
  int parameter_count = 0;
  for (auto& instruction : instructions_) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }
  }
  // If root_instruction is not specified, use the last added instruction.
  HloInstruction* root =
      root_instruction ? root_instruction : last_added_instruction_;
  CHECK_NE(nullptr, root);
  return absl::WrapUnique(new HloComputation(
      name_, parameter_count, &instructions_, root, fusion_instruction_));
}
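
// A minimal usage sketch of the builder (illustrative; `shape` is assumed to
// be a Shape defined by the caller):
//   HloComputation::Builder builder("add");
//   auto* x = builder.AddInstruction(
//       HloInstruction::CreateParameter(0, shape, "x"));
//   auto* y = builder.AddInstruction(
//       HloInstruction::CreateParameter(1, shape, "y"));
//   builder.AddInstruction(
//       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
//   std::unique_ptr<HloComputation> computation = builder.Build();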

HloComputation::HloComputation(
    const std::string& name, int parameter_count,
    std::vector<std::unique_ptr<HloInstruction>>* instructions,
    HloInstruction* root_instruction, HloInstruction* fusion_instruction)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(-1),
      root_instruction_(root_instruction),
      fusion_instruction_(fusion_instruction),
      is_fusion_computation_(fusion_instruction != nullptr),
      custom_call_instruction_(nullptr),
      is_custom_call_computation_(false) {
  param_instructions_.resize(parameter_count, nullptr);
  bool root_found = false;
  for (auto& instruction : *instructions) {
    if (instruction->opcode() == HloOpcode::kParameter) {
      int64_t param_no = instruction->parameter_number();
      CHECK(param_no >= 0 && param_no < parameter_count)
          << "\nERROR: invalid parameter number. Expected [0, "
          << parameter_count << "), got " << param_no;
      CHECK(param_instructions_[param_no] == nullptr)
          << "\nERROR: parameter number " << param_no
          << " already allocated in this computation";
      param_instructions_[param_no] = instruction.get();
    }
    root_found |= instruction.get() == root_instruction_;
    AddInstructionInternal(std::move(instruction));
  }
  CHECK(root_found)
      << "\nERROR: root instruction is not present in computation.";
}

HloComputation::~HloComputation() {
  if (fusion_instruction_ != nullptr) {
    CHECK(fusion_instruction_->fused_instructions_computation() == this);
    fusion_instruction_->ClearCalledComputations();
    fusion_instruction_ = nullptr;
  }
  if (IsAsyncComputation()) {
    for (auto* async_instr : async_instructions_) {
      CHECK(async_instr->async_wrapped_computation() == this);
      async_instr->ClearCalledComputations();
    }
    async_instructions_.clear();
  }
}

HloInstruction* HloComputation::AddInstruction(
    std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
  CHECK(instruction->opcode() != HloOpcode::kParameter)
      << "Parameter instructions cannot be added to a computation after "
      << "it has been built";
  if (!new_name.empty()) {
    instruction->SetAndSanitizeName(new_name);
  }
  return AddInstructionInternal(std::move(instruction));
}

HloInstruction* HloComputation::AddInstruction(
    std::unique_ptr<HloInstruction> instruction, const OpMetadata* metadata) {
  if (metadata != nullptr) {
    instruction->set_metadata(*metadata);
  }
  return AddInstruction(std::move(instruction));
}

HloInstruction* HloComputation::AddInstructionInternal(
    std::unique_ptr<HloInstruction> instruction) {
  if (parent() != nullptr) {
    instruction->UniquifyName(&parent()->instruction_name_uniquer());
    instruction->SetUniqueId(parent()->NewUniqueInstructionId());
  }
  instruction->set_parent(this);
  HloInstruction* pinst = instruction.get();
  instruction_iterators_[pinst] =
      instructions_.insert(instructions_.end(), std::move(instruction));
  return pinst;
}

HloInstruction* HloComputation::AddParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK(instruction->opcode() == HloOpcode::kParameter);
  CHECK(!IsFusionComputation() ||
        fusion_instruction_->operand_count() == param_instructions_.size());
  instruction->set_parent(this);
  param_instructions_.push_back(instruction.get());
  AddInstructionInternal(std::move(instruction));
  return instructions_.back().get();
}

HloInstruction* HloComputation::AddEntryComputationParameter(
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK_EQ(instruction->parameter_number(), num_parameters());
  CHECK(parent()->entry_computation() == this);

  HloModuleConfig config = parent()->config();
  config.mutable_entry_computation_layout()->add_parameter_layout(
      ShapeLayout(instruction->shape()));
  parent()->set_config(config);

  instruction->set_parent(this);
  param_instructions_.push_back(instruction.get());
  AddInstructionInternal(std::move(instruction));

  return instructions_.back().get();
}

Status HloComputation::ReplaceEntryComputationParameter(
    int64_t param_no, HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> instruction) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
  CHECK(parent()->entry_computation() == this);

  HloModuleConfig config = parent()->config();
  *config.mutable_entry_computation_layout()->mutable_parameter_layout(
      param_no) = ShapeLayout(instruction->shape());
  parent()->set_config(config);

  instruction->set_parent(this);
  param_instructions_[param_no] = instruction.get();
  AddInstructionInternal(std::move(instruction));

  return ForceRemoveInstruction(old_instruction);
}

Status HloComputation::RemoveParameter(int64_t param_no) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  HloInstruction* param_instruction = param_instructions_[param_no];
  auto param_instruction_iterator = param_instructions_.begin() + param_no;
  param_instructions_.erase(param_instruction_iterator);
  // Throw the removed parameter instruction away.
  TF_RETURN_IF_ERROR(ForceRemoveInstruction(param_instruction));

  // Renumber the parameters that followed the removed one: each is replaced
  // by a fresh parameter instruction whose parameter number is shifted down
  // by one.
  while (param_no < param_instructions_.size()) {
    param_instruction = param_instructions_[param_no];
    HloInstruction* new_instr =
        AddInstructionInternal(HloInstruction::CreateParameter(
            param_no, param_instruction->shape(), StrCat("param_", param_no)));
    TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
    param_instructions_[param_no] = new_instr;
    TF_RETURN_IF_ERROR(ForceRemoveInstruction(param_instruction));
    param_no++;
  }

  return OkStatus();
}

HloInstruction* HloComputation::ReplaceParameter(
    int64_t param_no, std::unique_ptr<HloInstruction> instruction) {
  CHECK_GE(param_no, 0);
  CHECK_LT(param_no, param_instructions_.size());
  CHECK(instruction->opcode() == HloOpcode::kParameter);
  CHECK(!IsFusionComputation() ||
        fusion_instruction_->operand_count() == param_instructions_.size());

  instruction->set_parent(this);
  HloInstruction* new_instruction =
      AddInstructionInternal(std::move(instruction));
  HloInstruction* old_instruction = param_instructions_[param_no];
  CHECK(
      old_instruction->ReplaceAllUsesWithDifferentShape(new_instruction).ok());
  param_instructions_[param_no] = new_instruction;
  CHECK(RemoveInstruction(old_instruction).ok());
  return new_instruction;
}

Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
  return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
}

Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
  return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
}

Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
  CHECK(allow_non_fusion || IsFusionComputation());
  int64_t removed = 0;
  for (int64_t i = 0; i < param_instructions_.size(); ++i) {
    HloInstruction* param_instruction = param_instructions_[i];
    if (param_instruction->IsDead()) {
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
      ++removed;
      continue;
    }

    if (removed > 0) {
      const int64_t param_no = i - removed;
      HloInstruction* new_instr = AddInstructionInternal(
          HloInstruction::CreateParameter(param_no, param_instruction->shape(),
                                          StrCat("param_", param_no)));
      TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
      param_instructions_[param_no] = new_instr;
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
    }
  }
  param_instructions_.resize(param_instructions_.size() - removed);
  return OkStatus();
}

bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
  // If the instruction has control predecessors or successors then we cannot
  // remove the instruction without violating ordering constraints (added, for
  // example, to avert interference due to buffer aliasing).
  if (!instruction->control_predecessors().empty() ||
      !instruction->control_successors().empty()) {
    return false;
  }

  if (instruction->opcode() == HloOpcode::kParameter &&
      !IsFusionComputation()) {
    return false;
  }

  return true;
}

bool HloComputation::HasSideEffect() const {
  for (auto* instruction : instructions()) {
    if (instruction->HasSideEffect()) {
      return true;
    }
  }
  return false;
}

bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {
  return inst->IsMarkedAsDead();
}

Status HloComputation::RemoveInstructionAndUnusedOperands(
    HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
  TF_RET_CHECK(root_instruction() != instruction);

  TF_RET_CHECK(instruction->IsDead());
  TF_RET_CHECK(IsSafelyRemovable(instruction))
      << "Cannot remove instruction: " << instruction->ToString();
  absl::flat_hash_set<HloInstruction*> removed;
  std::queue<HloInstruction*> worklist;
  worklist.push(instruction);
  while (!worklist.empty()) {
    HloInstruction* item = worklist.front();
    worklist.pop();

    if (removed.contains(item) || !item->IsDead() || !IsSafelyRemovable(item) ||
        (item->HasSideEffect() && item != instruction)) {
      continue;
    }
    for (int i = 0; i < item->operand_count(); ++i) {
      worklist.push(item->mutable_operand(i));
    }

    if (cleanup) {
      cleanup(item);
    }
    TF_RETURN_IF_ERROR(RemoveInstruction(item));
    removed.insert(item);
  }
  return OkStatus();
}

Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
  return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
}

Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
  return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
}

Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
                                             bool ignore_safety_check) {
  VLOG(2) << "Removing instruction " << instruction->name()
          << " from computation " << name();
  TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
      << "cannot remove instruction: " << instruction->ToString();
  TF_RET_CHECK(instruction->IsDead()) << "instruction " << instruction->name()
                                      << " is live and cannot be removed";
  TF_RET_CHECK(instruction->control_predecessors().empty())
      << "instruction " << instruction->name()
      << " has control predecessors and cannot be removed";
  TF_RET_CHECK(instruction->control_successors().empty())
      << "instruction " << instruction->name()
      << " has control successors and cannot be removed";

  auto inst_it = instruction_iterators_.find(instruction);
  TF_RET_CHECK(inst_it != instruction_iterators_.end());
  (*inst_it->second)->set_parent(nullptr);
  to_be_deleted_.emplace_back(inst_it->second->release());
  to_be_deleted_.back()->DetachFromOperandsAndUsers();
  // Clear all operands to avoid null operands.
  to_be_deleted_.back()->RemoveAllOperands();
  to_be_deleted_.back()->ClearCalledComputations();
  to_be_deleted_.back()->MarkAsDead();
  instructions_.erase(inst_it->second);
  instruction_iterators_.erase(inst_it);
  return OkStatus();
}

void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
                                          bool accept_different_shape) {
  // The shape of the root (ignoring layout) is an invariant of the computation
  // for non-fusion cases.
  if (!IsFusionComputation() && !accept_different_shape) {
    CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
                                root_instruction_->shape()))
        << new_root_instruction->shape() << " is incompatible with "
        << root_instruction_->shape();
  }
  bool root_found = false;
  for (auto& instruction : instructions_) {
    if (new_root_instruction == instruction.get()) {
      root_found = true;
      break;
    }
  }
  DCHECK(root_found);

  if (parent() && parent()->has_entry_computation() &&
      parent()->entry_computation() == this) {
    if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
                                       root_instruction_->shape())) {
      // Rebuild input output alias config now that we have a new output shape.
      parent()->input_output_alias_config() =
          HloInputOutputAliasConfig(new_root_instruction->shape());
    }
  }

  root_instruction_ = new_root_instruction;
}

namespace {

// Helper which builds a post order of the HLO call graph.
void ComputeComputationPostOrder(HloComputation* computation,
                                 absl::flat_hash_set<HloComputation*>* visited,
                                 std::vector<HloComputation*>* post_order) {
  if (visited->insert(computation).second) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        ComputeComputationPostOrder(called_computation, visited, post_order);
      }
    }
    post_order->push_back(computation);
  }
}

std::optional<int64_t> GetChannelId(const HloInstruction& inst) {
  // Note that we only include Send and RecvDone, as we want to create a
  // dependency between those, but not SendDone and Recv.
  switch (inst.opcode()) {
    case HloOpcode::kSend:
    case HloOpcode::kRecvDone:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kReduceScatter:
      return inst.channel_id();
    default:
      return std::nullopt;
  }
}

}  // namespace

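// Performs an iterative depth-first search from `root`, appending an
// instruction to `post_order` once all of its (transitive) operands and
// control predecessors have been emitted. `visited` tracks each instruction's
// state: kVisiting while its subtree is still on the stack, kVisited once it
// has been appended to the post order.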
void HloComputation::ComputeInstructionPostOrder(
    HloInstruction* root,
    HloComputation::ChannelDependencyGroup& channel_dependencies,
    absl::flat_hash_map<HloInstruction*, VisitState>& visited,
    std::vector<HloInstruction*>& post_order) const {
  std::vector<HloInstruction*> dfs_stack = {root};
  while (!dfs_stack.empty()) {
    HloInstruction& current = *dfs_stack.back();

    auto result = visited.insert({&current, kVisiting});
    if (!result.second) {  // We've already seen this instruction.
      dfs_stack.pop_back();
      if (result.first->second != kVisited) {
        CHECK_EQ(current.parent(), this)
            << "Instruction " << current.name()
            << " is not in the current computation (" << name() << ").";
        post_order.push_back(&current);
        result.first->second = kVisited;
      }
      continue;
    }

    // Add channel dependencies.
    // A RecvDone op must be preceded by the corresponding Send op.
    // Collectives with the same channel ID must be performed together, as
    // these represent MPMD-partitioned computations that will later be split
    // into separate modules, where the order must be preserved.
    std::optional<int64_t> channel_id =
        ((&current != root) && (current.opcode() != HloOpcode::kSend))
            ? GetChannelId(current)
            : std::nullopt;
    if (channel_id) {
      auto it = channel_dependencies.find(*channel_id);
      if (it != channel_dependencies.end()) {
        dfs_stack.insert(dfs_stack.end(), it->second.begin(), it->second.end());
        channel_dependencies.erase(it);
      }
    }

    // Add the operands to the stack in reverse order so the first operand is
    // processed first. This will produce a more natural ordering and a nicer
    // result for things like HLO stringification.
    const HloInstruction::InstructionVector& operands = current.operands();
    dfs_stack.insert(dfs_stack.end(), operands.rbegin(), operands.rend());

    const std::vector<HloInstruction*>& predecessors =
        current.control_predecessors();
    dfs_stack.insert(dfs_stack.end(), predecessors.begin(), predecessors.end());
  }
}

HloComputation::ChannelDependencyGroup
HloComputation::ComputeChannelDependencies() const {
  if (parent() && parent()->config().has_static_device_assignment() &&
      (parent()->config().static_device_assignment().computation_count() == 1 ||
       parent()->config().use_spmd_partitioning())) {
    return {};
  }

  ChannelDependencyGroup channel_dependencies;
  for (const auto& instruction : instructions_) {
    std::optional<int64_t> channel_id = GetChannelId(*instruction);
    if (channel_id)
      channel_dependencies[*channel_id].push_back(instruction.get());
  }
  return channel_dependencies;
}

// The predicate below is always false, so c_all_of succeeds only when the
// user list is empty; despite its name, this returns true iff `instruction`
// has no users.
static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
  return absl::c_all_of(instruction->users(),
                        [](HloInstruction* user) { return false; });
}

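// Returns the instructions of this computation in post order: every
// instruction appears after all of its operands and control predecessors.
// Instructions that share a channel id are additionally ordered according to
// the channel dependencies computed above.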
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
  ChannelDependencyGroup channel_dependencies = ComputeChannelDependencies();
  std::vector<HloInstruction*> post_order;
  post_order.reserve(instruction_count());
  std::vector<HloInstruction*> trace_instructions;  // Currently never
                                                    // populated.
  absl::flat_hash_map<HloInstruction*, VisitState> visited;
  visited.reserve(instruction_count());
  for (auto& instruction : instructions_) {
    if (HasOnlyTraceUsers(instruction.get())) {
      ComputeInstructionPostOrder(instruction.get(), channel_dependencies,
                                  visited, post_order);
    }
  }
  post_order.insert(post_order.end(), trace_instructions.begin(),
                    trace_instructions.end());
  CHECK_EQ(instructions_.size(), post_order.size())
      << "number of instructions does not match post order size";
  return post_order;
}

std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
    const {
  absl::flat_hash_set<HloComputation*> visited;
  std::vector<HloComputation*> post_order;

  // To avoid special handling of this computation, cast away const of
  // 'this'. 'this' is immediately removed from the post order after
  // construction.
  //
  // TODO(b/78350259): This violates const-correctness, since while the
  // original computation is not returned, we still retrieve non-const
  // computations from a const one. Consider also avoiding const for
  // HloComputation, or review XLA for const-correctness of
  // non-HloInstruction* types like this.
  ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
                              &post_order);

  // We don't want to include this computation in the post order.
  CHECK_EQ(this, post_order.back());
  post_order.pop_back();

  return post_order;
}

std::string HloComputation::ToString(const HloPrintOptions& options) const {
  return std::string(ToCord(options));
}

std::string HloComputation::ToString(
    const HloPrintOptions& options,
    absl::Span<const HloInstruction* const> instruction_order) const {
  return std::string(ToCord(options, instruction_order));
}

absl::Cord HloComputation::ToCord(const HloPrintOptions& options) const {
  return ToCord(options, MakeInstructionPostOrder());
}

absl::Cord HloComputation::ToCord(
    const HloPrintOptions& options,
    absl::Span<const HloInstruction* const> instruction_order) const {
  CHECK_EQ(instruction_order.size(), instruction_count());
  const std::string tab(2 * options.indent_amount(), ' ');

  absl::Cord result;
  result.Append(tab);

  if (!options.is_in_nested_computation()) {
    if (options.print_percent()) {
      result.Append("%");
    }
    if (options.print_ids()) {
      // When print_ids() is false, exclude the entry computation's name
      // because it includes an id that leads to a non-deterministic
      // fingerprint.
      result.Append(name());
      result.Append(" ");
    }
  }

  if (options.print_program_shape()) {
    result.Append(
        ShapeUtil::HumanString(ComputeProgramShape(options.print_ids())));
    result.Append(" ");
  }
  result.Append("{\n");

  {
    // Print the instructions in this computation.
    HloPrintOptions new_options =
        HloPrintOptions(options)
            .set_indent_amount(options.indent_amount() + 1)
            .set_is_in_nested_computation(true);

    const std::string new_tab(2 * new_options.indent_amount(), ' ');

    CanonicalNameMap name_map;
    for (const HloInstruction* const instruction : instruction_order) {
      DCHECK_EQ(this, instruction->parent());
      result.Append(new_tab);
      if (instruction == root_instruction_) {
        result.Append("ROOT ");
      }
      result.Append(
          instruction->ToStringWithCanonicalNameMap(new_options, &name_map));
      result.Append("\n");
    }
  }

  result.Append(tab);
  result.Append("}");
  if (options.print_ids() && !IsMainThread()) {
    // When print_ids() is false, exclude the entry computation's thread name
    // because it includes an id that leads to a non-deterministic fingerprint.
    result.Append(StrCat(", execution_thread=\"", execution_thread(), "\""));
  }
  return result;
}

HloComputationProto HloComputation::ToProto() const {
  HloComputationProto proto;
  CHECK(unique_id_ != -1)
      << "This computation does not have a valid id. Please make sure the "
         "computation is inside a module before dumping it.";
  proto.set_id(unique_id_);
  proto.set_name(name_);
  for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
    HloInstructionProto instruction_proto = instruction->ToProto();
    proto.add_instructions()->Swap(&instruction_proto);
  }
  proto.set_root_id(root_instruction()->unique_id());
  *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
  proto.set_is_fusion_computation(is_fusion_computation_);
  proto.set_execution_thread(IsMainThread() ? ""
                                            : std::string(execution_thread()));
  return proto;
}
653
654 /* static */ StatusOr<std::unique_ptr<HloComputation>>
CreateFromProto(const HloComputationProto & proto,const absl::flat_hash_map<int64_t,HloComputation * > & computation_map,bool prohibit_empty_literal)655 HloComputation::CreateFromProto(
656 const HloComputationProto& proto,
657 const absl::flat_hash_map<int64_t, HloComputation*>& computation_map,
658 bool prohibit_empty_literal) {
659 absl::flat_hash_map<int64_t, HloInstruction*> instruction_map;
660 absl::flat_hash_map<HloInstruction*, int64_t> to_proto_id;
661 std::vector<std::unique_ptr<HloInstruction>> instructions;
662 int64_t parameter_count = 0;
663 for (const HloInstructionProto& instruction_proto : proto.instructions()) {
664 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
665 HloInstruction::CreateFromProto(
666 instruction_proto, instruction_map, computation_map,
667 prohibit_empty_literal));
668 if (instruction->opcode() == HloOpcode::kParameter) {
669 parameter_count++;
670 }
671 TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
672 instruction_map[instruction_proto.id()] = instruction.get();
673 to_proto_id[instruction.get()] = instruction_proto.id();
674 instructions.push_back(std::move(instruction));
675 }
676
677 TF_RET_CHECK(proto.root_id() != -1);
678 TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
679 HloInstruction* root = instruction_map.at(proto.root_id());
680
681 // Sort the instructions in the proto id's order.
682 absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
683 const std::unique_ptr<HloInstruction>& b) {
684 return to_proto_id[a.get()] < to_proto_id[b.get()];
685 });
686
687 TF_RETURN_IF_ERROR([&]() -> Status {
688 std::vector<bool> parameters_seen(parameter_count);
689 int parameters_seen_count = 0;
690 for (auto& instruction : instructions) {
691 if (instruction->opcode() == HloOpcode::kParameter) {
692 int64_t param_no = instruction->parameter_number();
693 TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
694 << "Invalid parameter number. Expected [0, " << parameter_count
695 << "), got " << param_no;
696 TF_RET_CHECK(!parameters_seen[param_no])
697 << "Parameter number " << param_no
698 << " already allocated in this computation";
699 parameters_seen[param_no] = true;
700 parameters_seen_count++;
701 }
702 }
703 TF_RET_CHECK(parameters_seen_count == parameter_count)
704 << "Not all parameters in range [0, " << parameter_count
705 << ") were referenced";
706 return OkStatus();
707 }());
708
709 auto computation = absl::WrapUnique(
710 new HloComputation(proto.name(), parameter_count, &instructions, root,
711 /*fusion_instruction=*/nullptr));
712 computation->unique_id_ = proto.id();
713 computation->is_fusion_computation_ = proto.is_fusion_computation();
714 if (!proto.execution_thread().empty()) {
715 computation->SetExecutionThread(proto.execution_thread());
716 }
717 return std::move(computation);
718 }

void HloComputation::AppendInstructionsIntoCalledComputation(
    absl::Span<HloInstruction* const> instructions_to_append,
    HloInstruction* caller) {
  HloInstruction* root = instructions_to_append.front();
  TF_CHECK_OK(root->ReplaceAllUsesWith(caller));
  if (root == root_instruction()) {
    set_root_instruction(caller);
  }
  TF_CHECK_OK(RemoveInstruction(root));
  for (size_t i = 1; i < instructions_to_append.size(); ++i) {
    HloInstruction* instruction = instructions_to_append[i];
    caller->AppendInstructionIntoCalledComputation(instruction);
    if (instruction->IsDead()) {
      TF_CHECK_OK(RemoveInstruction(instruction));
    }
  }
}

HloInstruction* HloComputation::CreateFusionInstruction(
    absl::Span<HloInstruction* const> instructions_to_fuse,
    HloInstruction::FusionKind fusion_kind) {
  HloInstruction* root = instructions_to_fuse.front();
  HloInstruction* fusion_instruction = AddInstruction(
      HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
  AppendInstructionsIntoCalledComputation(instructions_to_fuse,
                                          fusion_instruction);
  return fusion_instruction;
}

HloInstruction* HloComputation::CreateCallInstruction(
    absl::Span<HloInstruction* const> instructions_to_call) {
  HloInstruction* root = instructions_to_call.front();
  HloInstruction* call_instruction =
      AddInstruction(HloInstruction::CreateCall(root->shape(), root));
  AppendInstructionsIntoCalledComputation(instructions_to_call,
                                          call_instruction);
  return call_instruction;
}

StatusOr<HloInstruction*> HloComputation::CreateAsyncInstructions(
    HloInstruction* instruction, absl::Span<const Shape> context_shapes,
    absl::string_view async_execution_thread) {
  Builder builder("async_computation");
  std::vector<HloInstruction*> parameters(instruction->operand_count());
  std::vector<Shape> parameter_shapes(instruction->operand_count());
  for (int i = 0; i < instruction->operand_count(); ++i) {
    const Shape& parameter_shape = instruction->operand(i)->shape();
    parameters[i] = builder.AddInstruction(HloInstruction::CreateParameter(
        i, parameter_shape, absl::StrCat("param_", i)));
    parameter_shapes[i] = parameter_shape;
  }
  HloInstruction* root = builder.AddInstruction(
      instruction->CloneWithNewOperands(instruction->shape(), parameters));
  HloComputation* async_computation =
      parent_->AddEmbeddedComputation(builder.Build(root));
  std::vector<Shape> start_shapes = {
      ShapeUtil::MakeTupleShape(parameter_shapes), root->shape()};
  for (const Shape& context_shape : context_shapes) {
    start_shapes.push_back(context_shape);
  }
  HloInstruction* async_start = AddInstruction(HloInstruction::CreateAsyncStart(
      ShapeUtil::MakeTupleShape(start_shapes), instruction->operands(),
      async_computation, /*async_group_id=*/std::nullopt,
      async_execution_thread));
  HloInstruction* async_done = AddInstruction(HloInstruction::CreateAsyncDone(
      root->shape(), async_start, async_computation,
      /*async_group_id=*/std::nullopt, async_execution_thread));
  async_start->set_metadata(instruction->metadata());
  async_start->CopyBackendConfigFrom(instruction);
  async_done->set_metadata(instruction->metadata());
  async_done->CopyBackendConfigFrom(instruction);
  TF_RETURN_IF_ERROR(ReplaceInstruction(instruction, async_done));
  return async_done;
}
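
// A rough sketch of the rewrite above (names and shapes illustrative): given
// %op = f32[8] add(%a, %b), the op is cloned into an async computation
//   async_computation { param_0, param_1 -> add }
// and replaced by the pair
//   %async-start = ((f32[8], f32[8]), f32[8]) async-start(%a, %b),
//       calls=async_computation
//   %async-done = f32[8] async-done(%async-start), calls=async_computation
// with every use of %op redirected to %async-done.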

StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
    HloInstruction* instruction, ShapeIndex* index,
    const std::function<
        HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* computation)>& copy_leaf) {
  if (instruction->shape().IsTuple()) {
    std::vector<HloInstruction*> elements;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         i++) {
      HloInstruction* gte =
          AddInstruction(HloInstruction::CreateGetTupleElement(
              ShapeUtil::GetTupleElementShape(instruction->shape(), i),
              instruction, i));

      index->push_back(i);
      TF_ASSIGN_OR_RETURN(HloInstruction * element,
                          DeepCopyHelper(gte, index, copy_leaf));
      elements.push_back(element);
      index->pop_back();
    }
    return AddInstruction(HloInstruction::CreateTuple(elements));
  }
  if (instruction->shape().IsToken()) {
    // Tokens have no on-device representation and cannot be copied. Pass
    // through transparently.
    return instruction;
  }

  // Array shape.
  TF_RET_CHECK(instruction->shape().IsArray());
  return copy_leaf(instruction, *index, this);
}

StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
    HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
    ShapeTree<HloInstruction*>* copies_added) {
  if (instruction->parent() != this) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: instruction is not in computation %s",
        instruction->name(), name());
  }
  if (indices_to_copy != nullptr &&
      !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: given shape tree of indices to copy "
        "has incompatible shapes: %s vs. %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanString(indices_to_copy->shape()));
  }

  ShapeIndex index;
  auto copy_leaf = [indices_to_copy, copies_added](
                       HloInstruction* leaf, const ShapeIndex& leaf_index,
                       HloComputation* computation) {
    if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
      HloInstruction* copy = computation->AddInstruction(
          HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
      if (copies_added != nullptr) {
        *copies_added->mutable_element(leaf_index) = copy;
      }
      return copy;
    }
    // Elements which are not to be copied are passed through transparently.
    return leaf;
  };
  return DeepCopyHelper(instruction, &index, copy_leaf);
}
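
// For example (illustrative names), deep-copying a tuple-shaped instruction
// %t = (f32[2], f32[3]) with all leaf indices marked for copy emits:
//   %gte.0 = f32[2] get-tuple-element(%t), index=0
//   %copy.0 = f32[2] copy(%gte.0)
//   %gte.1 = f32[3] get-tuple-element(%t), index=1
//   %copy.1 = f32[3] copy(%gte.1)
//   %tuple = (f32[2], f32[3]) tuple(%copy.0, %copy.1)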

StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
    HloInstruction* instruction,
    const std::function<
        HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* computation)>& copy_leaf) {
  if (instruction->parent() != this) {
    return FailedPrecondition(
        "Can't deep copy instruction %s: instruction is not in computation %s",
        instruction->name(), name());
  }
  ShapeIndex index;
  return DeepCopyHelper(instruction, &index, copy_leaf);
}

ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
  ProgramShape program_shape;

  for (auto* param_instruction : param_instructions_) {
    *program_shape.add_parameters() = param_instruction->shape();
    *program_shape.add_parameter_names() =
        PrintName(param_instruction->name(), include_ids);
  }
  *program_shape.mutable_result() = root_instruction_->shape();

  return program_shape;
}

bool HloComputation::EqualInternal(const HloComputation& other,
                                   bool is_layout_sensitive,
                                   bool ignore_channel_id_values,
                                   bool ignore_thread) const {
  if (this == &other) {
    return true;
  }
  absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
      visited;
  std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;

  worklist.push_back({root_instruction(), other.root_instruction()});

  while (!worklist.empty()) {
    auto pair = worklist.back();
    worklist.pop_back();

    if (visited.contains(pair)) {
      continue;
    }
    visited.emplace(pair);
    // TODO(b/123082518): Avoid recursively invoking Equal because it may
    // cause a stack overflow with deeply nested subcomputations.
    auto operands_eq = [](const HloInstruction*, const HloInstruction*) {
      return true;
    };
    auto comp_eq = [&](const HloComputation* a, const HloComputation* b) {
      return a->EqualInternal(*b, is_layout_sensitive,
                              ignore_channel_id_values, ignore_thread);
    };
    bool identical_ignoring_operands =
        ignore_channel_id_values
            ? pair.first->IdenticalIgnoringChannelIdValues(
                  *pair.second, operands_eq, comp_eq, is_layout_sensitive)
            : pair.first->Identical(*pair.second, operands_eq, comp_eq,
                                    is_layout_sensitive);
    if (!identical_ignoring_operands) {
      return false;
    }
    for (size_t i = 0; i < pair.first->operands().size(); ++i) {
      worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
    }
  }

  if (!ignore_thread) {
    return execution_thread() == other.execution_thread();
  }
  return true;
}

Status HloComputation::ReplaceWithNewInstruction(
    HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> new_instruction) {
  return ReplaceInstruction(old_instruction,
                            AddInstruction(std::move(new_instruction)));
}

Status HloComputation::ReplaceWithNewEntryComputationParameter(
    HloInstruction* old_instruction,
    std::unique_ptr<HloInstruction> new_instruction) {
  return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
                                                 std::move(new_instruction)));
}

StatusOr<bool> HloComputation::ReplaceInstruction(
    HloInstruction* old_instruction, HloInstruction* new_instruction,
    bool preserve_sharding) {
  TF_RET_CHECK(
      ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
      << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
      << ShapeUtil::HumanString(new_instruction->shape());
  return ReplaceInstructionWithDifferentShape(old_instruction, new_instruction,
                                              preserve_sharding);
}

Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
                                          HloInstruction* new_instruction) {
  TF_ASSIGN_OR_RETURN(bool changed,
                      ReplaceInstruction(old_instruction, new_instruction,
                                         /*preserve_sharding=*/false));
  DCHECK(changed);
  return OkStatus();
}

StatusOr<bool> HloComputation::ReplaceInstructionWithDifferentShape(
    HloInstruction* old_instruction, HloInstruction* new_instruction,
    bool preserve_sharding) {
  if (preserve_sharding && new_instruction->has_sharding() &&
      old_instruction->has_sharding() &&
      !new_instruction->has_compatible_sharding(old_instruction)) {
    VLOG(10) << "Skipping replacement due to incompatible sharding";
    return false;
  }
  VLOG(10) << "transformed " << old_instruction->ToString() << " to "
           << new_instruction->ToString();
  // Try to add metadata for HLO instructions that are created to replace
  // existing HLO instructions (e.g. during optimizations). The assumption is
  // that the old instruction and the new instruction would perform the same
  // function, and that they would be correlated to the same TF op. This might
  // not always be correct since HLO optimizations can cross TF op boundaries.
  // But still this seems to be better than nothing.
  bool overwrite_op_name = new_instruction->metadata().op_name().empty() &&
                           !old_instruction->metadata().op_name().empty();
  bool overwrite_pass_id =
      new_instruction->metadata().op_name().empty() &&
      new_instruction->metadata().logical_creation_pass_id() == 0 &&
      old_instruction->metadata().logical_creation_pass_id() != 0;
  if (overwrite_op_name || overwrite_pass_id) {
    new_instruction->set_metadata(old_instruction->metadata());
  }
  if (new_instruction->frontend_attributes().map().empty()) {
    new_instruction->set_frontend_attributes(
        old_instruction->frontend_attributes());
  }

  // Like the metadata above, if the user didn't specify any sharding
  // information on the new instruction we should copy the old sharding
  // information (if any).
  if (!new_instruction->has_sharding()) {
    new_instruction->set_sharding(old_instruction->sharding_ptr());
  }

  TF_RETURN_IF_ERROR(
      old_instruction->ReplaceAllUsesWithDifferentShape(new_instruction));

  // Preserve the old instruction's name if the new and old instructions have
  // the same opcode. This makes it easier to follow instructions as they're
  // mutated through passes.
  if (old_instruction->opcode() == new_instruction->opcode() &&
      (old_instruction->opcode() != HloOpcode::kCustomCall ||
       old_instruction->custom_call_target() ==
           new_instruction->custom_call_target())) {
    new_instruction->SetAndSanitizeName(old_instruction->name());
  }

  TF_RETURN_IF_ERROR(RemoveInstructionAndUnusedOperands(old_instruction));
  return true;
}

Status HloComputation::ReplaceInstructionWithDifferentShape(
    HloInstruction* old_instruction, HloInstruction* new_instruction) {
  TF_ASSIGN_OR_RETURN(bool changed, ReplaceInstructionWithDifferentShape(
                                        old_instruction, new_instruction,
                                        /*preserve_sharding=*/false));
  DCHECK(changed);
  return OkStatus();
}

std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
  std::vector<HloInstruction*> unreachable_roots;
  for (auto* instruction : instructions()) {
    if (instruction->IsDead() && instruction->control_successors().empty()) {
      unreachable_roots.push_back(instruction);
    }
  }
  VLOG(3) << "Unreachable roots:"
          << absl::StrJoin(unreachable_roots, "\n\t",
                           [](std::string* out, const HloInstruction* hlo) {
                             absl::StrAppend(out, hlo->ToString());
                           });
  return unreachable_roots;
}

Status HloComputation::AcceptWithOperandOrder(
    DfsHloVisitor* visitor,
    const HloInstruction::CompareFunction& operand_order) const {
  // Visit unreachable roots. Beware that the visitor might delete the
  // currently visited root, which would invalidate iterators if the
  // unreachable roots weren't computed ahead of time.
  for (HloInstruction* root : CollectUnreachableRoots()) {
    TF_RETURN_IF_ERROR(
        root->AcceptWithOperandOrder(visitor, operand_order,
                                     /*call_finish_visit=*/false));
  }
  // Visit the computation root instruction last.
  return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
                                                    /*call_finish_visit=*/true);
}

std::unique_ptr<HloComputation> HloComputation::Clone(
    const std::string& suffix, HloCloneContext* context) {
  return CloneWithReplacements(
      /*replacements=*/nullptr,
      /*extra_parameters=*/{}, context, suffix);
}

std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
    HloCloneContext* context, const std::string& suffix) {
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacements;
  replacements.emplace(std::move(r1));
  return CloneWithReplacements(&replacements, /*extra_parameters=*/{}, context,
                               suffix);
}

std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
    HloCloneContext* context, const std::string& suffix) {
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacements;
  replacements.emplace(std::move(r1));
  replacements.emplace(std::move(r2));
  return CloneWithReplacements(&replacements, /*extra_parameters=*/{}, context,
                               suffix);
}

std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
    std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
    HloCloneContext* context, const std::string& suffix) {
  absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
      replacements;
  replacements.emplace(std::move(r1));
  replacements.emplace(std::move(r2));
  replacements.emplace(std::move(r3));
  return CloneWithReplacements(&replacements, /*extra_parameters=*/{}, context,
                               suffix);
}
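
// A minimal usage sketch (illustrative; assumes literal_util.h is available):
// clone a computation while swapping one instruction for a new constant. A
// null replacement instead drops the instruction from the clone (it must be
// unused there).
//   auto clone = computation->CloneWithReplacementPairs(
//       {old_instr, HloInstruction::CreateConstant(LiteralUtil::Zero(F32))});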

namespace {

// Sorts unordered_instructions according to the order of ordered_instructions,
// using MappedPtrContainerSorter. context and replace are used to map
// instructions in ordered_instructions to instructions in
// unordered_instructions. Unmapped parameter instructions are placed just
// after the last parameter instruction in the sorted mapped instruction order.
// All other unmapped instructions are placed at the end.
void SortClonedInstructions(
    const HloCloneContext& context,
    const std::function<const HloInstruction*(const HloInstruction*)>& replace,
    const HloComputation& computation,
    const HloComputation::InstructionList& ordered_instructions,
    std::vector<std::unique_ptr<HloInstruction>>& unordered_instructions) {
  using InstructionSorter = MappedPtrContainerSorter<HloInstruction>;
  InstructionSorter::MapPtrFn instruction_mapper =
      [&context, &replace](const HloInstruction* i) {
        return context.FindInstruction(replace(i));
      };
  size_t num_mapped_instructions = 0;
  size_t mapped_index_of_last_parameter_plus_one = 0;
  for (const auto& instruction : ordered_instructions) {
    if (!instruction_mapper(instruction.get())) {
      continue;
    }
    ++num_mapped_instructions;
    if (!dynamic_cast<const HloParameterInstruction*>(instruction.get())) {
      continue;
    }
    mapped_index_of_last_parameter_plus_one = num_mapped_instructions;
  }
  InstructionSorter::UnmappedPtrIndexFn unmapped_ptr_index =
      [num_mapped_instructions,
       mapped_index_of_last_parameter_plus_one](const HloInstruction* i) {
        if (dynamic_cast<const HloParameterInstruction*>(i)) {
          if (num_mapped_instructions > 0 &&
              mapped_index_of_last_parameter_plus_one > 0) {
            return mapped_index_of_last_parameter_plus_one - 1;
          }
          return InstructionSorter::IndexBeforeMappedElementsFn()(i);
        }
        return InstructionSorter::IndexAfterMappedElementsFn()(i);
      };
  auto status =
      InstructionSorter::Sort(instruction_mapper, unmapped_ptr_index,
                              ordered_instructions, unordered_instructions);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to reorder instructions while cloning computation: "
               << computation.name() << "; " << status;
  }
}

// For cloned instructions, sorts their users, control predecessors, and
// control successors, according to the orders of those lists in the original
// instructions, before cloning. context and replace help us to map original
// instructions to cloned instructions.
void SortClonedInstructionUsersAndControlLists(
    const HloCloneContext& context,
    const std::function<const HloInstruction*(const HloInstruction*)>& replace,
    const HloComputation::InstructionList& sorted_instructions) {
  using InstructionSorter = MappedPtrContainerSorter<HloInstruction>;
  InstructionSorter::MapPtrFn instruction_mapper =
      [&context, &replace](const HloInstruction* i) {
        return context.FindInstruction(replace(i));
      };
  for (const std::unique_ptr<HloInstruction>& instruction :
       sorted_instructions) {
    HloInstruction* cloned_instruction =
        context.FindInstruction(replace(instruction.get()));
    if (!cloned_instruction) {
      continue;
    }
    cloned_instruction->SortInstructionUsersAndControlLists(instruction_mapper,
                                                            *instruction);
  }
}

}  // namespace

std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
    const absl::flat_hash_map<const HloInstruction*,
                              std::unique_ptr<HloInstruction>>* replacements,
    absl::Span<const HloInstruction* const> extra_parameters,
    HloCloneContext* context, const std::string& suffix,
    const HloInstruction* new_root) {
  std::unique_ptr<HloCloneContext> context_ptr;
  if (context == nullptr) {
    context_ptr = std::make_unique<HloCloneContext>(parent(), suffix);
    context = context_ptr.get();
  }
  if (new_root == nullptr) {
    new_root = root_instruction();
  }

  // Look up instr in the replacements map, and return either the replacement,
  // or instr, if the replacement isn't present.
  //
  // Note: This can return null, indicating that instr should not be present
  // in the new computation.
  auto replace = [&](const HloInstruction* instr) {
    if (!replacements) return instr;
    auto it = replacements->find(instr);
    return it != replacements->end() ? it->second.get() : instr;
  };

  VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";

  // We want to do a postorder walk over [replace(i) for i in instructions_].
  // We can't reuse MakeInstructionPostOrder() for this, because that will
  // generate a postorder of plain instructions_, and our replacements may
  // change the postorder!
  //
  // The postorder we want here is simpler than what MakeInstructionPostOrder()
  // does -- we only care about operand dependencies -- so let's just do it
  // ourselves.
  std::vector<const HloInstruction*> postorder;
  absl::flat_hash_map<const HloInstruction*, VisitState> visited;
  for (const auto& instr : instructions_) {
    std::vector<const HloInstruction*> dfs_stack;
    const HloInstruction* new_instr = replace(instr.get());
    if (!new_instr) {
      continue;
    }
    dfs_stack.push_back(new_instr);

    while (!dfs_stack.empty()) {
      auto* cur = dfs_stack.back();
      auto it = visited.find(cur);
      if (it != visited.end()) {
        dfs_stack.pop_back();
        if (it->second == kVisited) {
          continue;
        }
        CHECK_EQ(it->second, kVisiting);
        postorder.push_back(cur);
        it->second = kVisited;
        continue;
      }

      visited.insert({cur, kVisiting});
      for (HloInstruction* operand : cur->operands()) {
        const HloInstruction* new_operand = replace(operand);
        if (new_operand) {
          dfs_stack.emplace_back(new_operand);
        }
      }
    }
  }

  std::vector<std::unique_ptr<HloInstruction>> instructions;
  // First add the extra parameters to 'instructions'.
  for (const auto& instr : extra_parameters) {
    CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
        << "Only parameter instructions are allowed in 'extra_parameters'";
    instructions.emplace_back(instr->Clone());
  }
  for (auto instr : postorder) {
    std::vector<HloInstruction*> new_operands;
    for (auto operand : instr->operands()) {
      auto replaced_operand = replace(operand);
      CHECK_NE(replaced_operand, nullptr)
          << "replacements map tried to eliminate a used instruction "
          << operand->ToString() << ", used by " << instr->ToString();
      new_operands.push_back(context->GetInstruction(replaced_operand));
    }
    std::unique_ptr<HloInstruction> new_instr =
        instr->CloneWithNewOperands(instr->shape(), new_operands, context);
    if (instr->opcode() == HloOpcode::kParameter &&
        instr->parameter_replicated_at_leaf_buffers().has_value()) {
      new_instr->set_parameter_replicated_at_leaf_buffers(
          instr->parameter_replicated_at_leaf_buffers().value());
    }
    instructions.push_back(std::move(new_instr));
  }

  // To make clone behavior match uncloned behavior, we reorder instructions
  // to match the order in instructions_.
  SortClonedInstructions(*context, replace, *this, instructions_, instructions);

  Builder builder(suffix.empty() ? name() : name() + "." + suffix);
  for (auto& instr : instructions) {
    builder.AddInstruction(std::move(instr));
  }
  auto result = builder.Build(
      /*root_instruction=*/context->GetInstruction(replace(new_root)));

  // Clone control dependencies.
  for (auto instr : postorder) {
    HloInstruction* new_instr = context->GetInstruction(instr);
    for (auto successor : instr->control_successors()) {
      auto replaced_successor = replace(successor);
      // successor may not have been remapped, because it might have been
      // removed by the replacements map.
      if (replaced_successor != nullptr) {
        TF_CHECK_OK(new_instr->AddControlDependencyTo(
            context->GetInstruction(replaced_successor)));
      }
    }
  }

  // To make clone behavior match uncloned behavior, we reorder the user and
  // control lists, kept by cloned instructions.
  SortClonedInstructionUsersAndControlLists(*context, replace, instructions_);

  context->MapComputation(this, result.get());
  result->SetExecutionThread(execution_thread());

  return result;
}

void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
  name_ = name_uniquer->GetUniqueName(name_);
}

HloInstruction* HloComputation::GetInstructionWithName(
    absl::string_view name) {
  auto instructions_in_computation = instructions();
  auto it = absl::c_find_if(
      instructions_in_computation,
      [&](HloInstruction* instr) { return instr->name() == name; });
  return it == instructions_in_computation.end() ? nullptr : *it;
}

bool HloComputation::IsEntryComputation() const {
  return parent()->entry_computation() == this;
}

}  // namespace xla