/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
HloModule::HloModule(const string& name, HloModuleConfig config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(std::move(config)),
      unique_id_(next_unique_module_id_++) {}

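// Installs the given schedule, verifying that it belongs to and is valid for
// this module.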
Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return Status::OK();
}

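// Replaces the entry computation, resetting the default computation layout
// and the input/output alias config to match the new entry root shape.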
void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) {
  entry_computation_ = entry_computation;
  config_.SetDefaultComputationLayout(
      entry_computation_->ComputeProgramShape());
  input_output_alias_config_ = HloInputOutputAliasConfig(
      entry_computation_->root_instruction()->shape());
}

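// Adds a computation to the module. If uniquify_identifiers is true, the
// computation's name and instruction names are uniquified and fresh ids are
// assigned; otherwise they are preserved as-is and only registered with the
// uniquifiers so that later additions cannot collide with them.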
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    // If the module configuration has no entry computation layout set, create
    // a default one based on the program shape.
    if (!config_.has_entry_computation_layout()) {
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
    input_output_alias_config_ = HloInputOutputAliasConfig(
        entry_computation_->root_instruction()->shape());
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set the unique id of this computation to the id of its root
    // instruction.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or its instructions, but
    // run the names through the uniquifiers to prevent future name collisions
    // for computations and instructions created later. Also, set
    // next_unique_id_ to one greater than the max unique id of any
    // instruction (or the computation) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true);
}

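// Removes the given (non-entry) computation from the module, dropping it
// from the schedule first if one is present (fusion computations are not
// scheduled, so they are skipped there).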
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  if (has_schedule() && !to_remove->IsFusionComputation()) {
    schedule_->remove_computation(to_remove);
  }

  auto it = absl::c_find_if(
      computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return Status::OK();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_identifiers=*/true);
}

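// Rewires every called-computation edge in the module according to
// `replacements` and drops the replaced computations from the module.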
void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter:
        case HloOpcode::kSort: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          for (int b = 0; b < instruction->branch_count(); ++b) {
            HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
                replacements, instruction->branch_computation(b), nullptr);
            if (new_computation != nullptr) {
              instruction->set_branch_computation(b, new_computation);
            }
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  s << "HloModule " << PrintName(name(), options.print_ids());
  if (has_schedule()) {
    TF_CHECK_OK(schedule().Verify());
    s << ", is_scheduled=true";
  }
  s << "\n\n";
  const auto& computations = options.canonicalize_computations()
                                 ? MakeComputationSortedByContent()
                                 : MakeComputationPostOrder();
  for (const HloComputation* computation : computations) {
    if (!options.print_computation(computation)) {
      continue;
    }
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
      s << computation->ToString(
               options, schedule().sequence(computation).instructions())
        << "\n\n";
    } else {
      s << computation->ToString(options) << "\n\n";
    }
  }
  return s.str();
}

HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  proto.set_entry_computation_id(entry_computation_->unique_id());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  *proto.mutable_host_program_shape() =
      entry_computation_layout().ComputeProgramShape().ToProto();
  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
  *proto.mutable_dynamic_parameter_binding() =
      dynamic_parameter_binding().ToProto();
  return proto;
}

Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
  absl::flat_hash_set<string> computation_names;
  absl::flat_hash_set<int> computation_ids;
  absl::flat_hash_set<string> instruction_names;
  absl::flat_hash_set<int> instruction_ids;

  for (const HloComputation* computation : computations()) {
    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());

    TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
        << "Computation id is not unique: " << computation->unique_id();
    computation_ids.insert(computation->unique_id());

    for (const HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());

      TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
          << "Instruction id is not unique: " << instruction->unique_id();
      instruction_ids.insert(instruction->unique_id());
    }
  }
  return Status::OK();
}

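// Deserializes a module from `proto`. Computation and instruction names and
// ids are preserved from the proto (not re-uniquified), so their uniqueness
// is verified explicitly after construction.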
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    bool prohibit_empty_literal) {
  VLOG(2) << "CreateFromProto()";
  XLA_VLOG_LINES(3, proto.DebugString());

  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape expected_program_shape(proto.host_program_shape());
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  absl::flat_hash_map<int64, HloComputation*> computation_map;
  absl::flat_hash_map<HloComputation*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloComputation>> computations;
  HloComputation* entry = nullptr;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(computation_proto, computation_map,
                                        prohibit_empty_literal));
    CHECK_NE(computation.get(), nullptr);
    int64 computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    computation_map[computation_id] = computation.get();
    to_proto_id[computation.get()] = computation_id;
    if (computation_id == proto.entry_computation_id()) {
      entry = computation.get();
    }
    computations.push_back(std::move(computation));
  }
  TF_RET_CHECK(entry != nullptr);

  auto module = absl::make_unique<HloModule>(proto.name(), module_config);

  // Sort the computations in proto id order.
  absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
                                 const std::unique_ptr<HloComputation>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Add sorted computations to the module.
  for (auto& computation : computations) {
    bool is_entry = computation.get() == entry;
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    module->AddComputationInternal(std::move(computation), is_entry,
                                   /*uniquify_identifiers=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config_,
      HloInputOutputAliasConfig::CreateFromProto(
          entry->ComputeProgramShape().result(), proto.input_output_alias()));

  TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
                      DynamicParameterBinding::CreateFromProto(
                          proto.dynamic_parameter_binding()));

  // Because we didn't uniquify the names or the ids, double-check that the
  // instruction and computation names and ids from the proto are unique.
  TF_RETURN_IF_ERROR(
      module->CheckUniqueNamesAndIdsForComputationsAndInstructions());

  if (proto.has_schedule()) {
    TF_ASSIGN_OR_RETURN(
        HloSchedule schedule,
        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  }

  return std::move(module);
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
    const ProgramShape& program_shape, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  HloModuleConfig module_config(ProgramShape{program_shape});
  module_config.set_debug_options(debug_options);
  if (execution_options) {
    if (execution_options->num_replicas() > 0) {
      module_config.set_replica_count(execution_options->num_replicas());
    }
    if (execution_options->num_partitions() > 0) {
      module_config.set_num_partitions(execution_options->num_partitions());
    }
    if (execution_options->has_device_assignment()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                          DeviceAssignment::Deserialize(
                              execution_options->device_assignment()));
      module_config.set_static_device_assignment(*device_assignment);
      if (execution_options->num_replicas() > 0) {
        CHECK_EQ(module_config.static_device_assignment().replica_count(),
                 module_config.replica_count());
      }
      if (execution_options->num_partitions() > 0) {
        CHECK_EQ(module_config.static_device_assignment().computation_count(),
                 module_config.num_partitions());
      }
    }
  }

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  TF_RET_CHECK(module.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape program_shape(module.host_program_shape());
  return CreateModuleConfigFromShape(program_shape, debug_options,
                                     execution_options);
}

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
                                 const absl::flat_hash_set<HloInstruction*>&
                                     instructions_in_subcomputation) {
  return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
    return !instructions_in_subcomputation.contains(user);
  });
}
}  // anonymous namespace

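// Outlines the given instructions, which must be in topological order and
// form an expression with a single output, into a new computation and
// replaces them in `computation` with a call to that computation.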
HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new
  // outlined function.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the root of the original computation (i.e. has no
    // users at all).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
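  // Remove the outlined instructions from the original computation in reverse
  // topological order so that no instruction is removed while it still has
  // users within the computation.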
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

int64 HloModule::instruction_count() const {
  int64 n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

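// Returns the module's computations ordered so that callees appear before
// their callers (a post order over the call graph).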
std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  absl::flat_hash_set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  absl::flat_hash_set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (!nonroot_computations.contains(computation.get())) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (!added_computations.contains(embedded_computation)) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK(!added_computations.contains(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  if (post_order.size() != computations_.size()) {
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Computation count mismatch: post_order=" << post_order.size()
               << " computation_count=" << computations_.size();
  }
  return post_order;
}

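// Returns the computations sorted by content: first by instruction count,
// then by their fingerprint string representation.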
std::vector<HloComputation*> HloModule::MakeComputationSortedByContent() const {
  auto result = MakeComputationPostOrder();
  std::sort(result.begin(), result.end(),
            [](HloComputation* a, HloComputation* b) {
              if (a->instruction_count() != b->instruction_count()) {
                return a->instruction_count() < b->instruction_count();
              }
              return a->ToString(HloPrintOptions::Fingerprint()) <
                     b->ToString(HloPrintOptions::Fingerprint());
            });
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result;
  for (auto* c : computations()) {
    if (c->IsFusionComputation()) {
      continue;
    }
    result.push_back(c);
  }
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted()
    const {
  auto result = MakeNonfusionComputations();
  std::sort(result.begin(), result.end(),
            [](HloComputation* a, HloComputation* b) {
              return a->name() < b->name();
            });
  return result;
}

std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  return Clone(config(), suffix);
}

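// Clones the entry computation (and, transitively, everything it calls) into
// a new module, carrying over the schedule for the cloned computations when
// one is present and valid.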
std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
                                            const string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config);

  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));

  if (has_schedule() && schedule().Verify().ok()) {
    HloSchedule clone_schedule(module.get());
    for (HloComputation* computation : computations()) {
      if (schedule().is_computation_scheduled(computation)) {
        HloInstructionSequence& clone_sequence =
            clone_schedule.GetOrCreateSequence(
                context.GetComputation(computation));
        for (const HloInstruction* instruction :
             schedule().sequence(computation).instructions()) {
          clone_sequence.push_back(context.GetInstruction(instruction));
        }
      }
    }
    TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
  }
  return module;
}

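// Removes computations that are not reachable from the entry computation:
// the entry is cloned into a throwaway module to discover the reachable set,
// and every computation the clone never visited is removed.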
Status HloModule::RemoveUnusedComputations() {
  std::string suffix = "tmp";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config());
  HloCloneContext context(module.get(), suffix);
  entry_computation_->Clone(suffix, &context);
  std::vector<HloComputation*> to_remove;
  for (auto computation : computations()) {
    auto found_computation = context.FindComputation(computation);
    if (found_computation == nullptr) {
      to_remove.push_back(computation);
    }
  }
  for (auto computation : to_remove) {
    TF_RETURN_IF_ERROR(RemoveEmbeddedComputation(computation));
  }
  return Status::OK();
}

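// Clones `computation` (and everything it calls) into this module, reusing a
// previously cloned computation if `context` already contains one.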
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
                                                HloCloneContext* context) {
  HloComputation* new_computation;
  if (context != nullptr) {
    if ((new_computation = context->FindComputation(computation)) != nullptr) {
      return new_computation;
    }
    new_computation =
        AddEmbeddedComputation(computation->Clone(context->suffix(), context));
  } else {
    new_computation = AddEmbeddedComputation(computation->Clone(""));
  }
  return new_computation;
}

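// Returns a random 64-bit value from the module-local generator; the mutex
// makes this safe to call from multiple threads.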
uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

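// Returns the computation with the given name via a linear scan, or nullptr
// if no such computation exists.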
HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
  auto computations_in_module = computations();
  auto it = absl::c_find_if(
      computations_in_module,
      [&](HloComputation* computation) { return computation->name() == name; });
  return it == computations_in_module.end() ? nullptr : *it;
}

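// Combines the hash of the entry computation layout with the hash of the
// entry computation's root instruction.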
uint64 HloModule::Hash() const {
  return tensorflow::Hash64Combine(
      entry_computation_layout().Hash(),
      entry_computation()->root_instruction()->Hash());
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla