1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <set>
21 #include <sstream>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_cat.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_computation.h"
33 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
34 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
35 #include "tensorflow/compiler/xla/shape_util.h"
36 #include "tensorflow/compiler/xla/types.h"
37 #include "tensorflow/core/lib/core/errors.h"
38 #include "tensorflow/core/lib/gtl/map_util.h"
39 #include "tensorflow/core/lib/hash/hash.h"
40 #include "tensorflow/core/platform/fingerprint.h"
41 #include "tensorflow/core/platform/stacktrace.h"
42 #include "tensorflow/core/platform/types.h"
43
44 namespace xla {
45
// Constructs an empty module. The provided name is sanitized into a valid
// HLO identifier, and the module receives a process-unique id drawn from the
// monotonically increasing next_unique_module_id_ counter.
HloModule::HloModule(const string& name, HloModuleConfig config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(std::move(config)),
      unique_id_(next_unique_module_id_++),
      metadata_(tensorflow::Env::Default()) {
  // Record this module's id as the canonical id in its metadata.
  metadata_.set_canonical_module_id(unique_id_);
}
53
// Installs `schedule` as this module's instruction schedule. The schedule
// must belong to this module and must verify cleanly; on failure an error
// Status is returned and the previous schedule (if any) is left in place.
Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return Status::OK();
}
60
// Points the module at a new entry computation and rebuilds the derived
// state that depends on the entry: the default computation layout and the
// input/output aliasing config (which is keyed off the entry root's shape).
// NOTE(review): assumes `entry_computation` is already owned by this module;
// this function does not take ownership -- confirm at call sites.
void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) {
  entry_computation_ = entry_computation;
  config_.SetDefaultComputationLayout(
      entry_computation_->ComputeProgramShape());
  // Aliasing is expressed against the entry root shape, so reset it.
  input_output_alias_config_ = HloInputOutputAliasConfig(
      entry_computation_->root_instruction()->shape());
}
68
// Takes ownership of `computation` and appends it to the module.
//  - is_entry: installs the computation as the (single) entry computation and
//    derives the entry layout and input/output aliasing config from it.
//  - uniquify_identifiers: if true, renames the computation and all of its
//    instructions through the module's name uniquers and assigns fresh unique
//    ids. If false (used when deserializing, where stable names/ids matter),
//    existing names/ids are kept, but are still fed through the uniquers and
//    the id counter so later additions cannot collide.
//  - preserve_entry_layouts: for an entry computation, keep the layouts that
//    already exist on it (if the config has them) instead of installing the
//    default layout.
// Returns a non-owning pointer to the added computation.
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers, bool preserve_entry_layouts) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    if (preserve_entry_layouts) {
      config_.SetComputationLayoutIfExists(
          entry_computation_->ComputeProgramShape());
    } else if (!config_.has_entry_computation_layout()) {
      // If the module configuration has no entry layout computation set, create
      // a default one based on the program shape.
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
    input_output_alias_config_ = HloInputOutputAliasConfig(
        entry_computation_->root_instruction()->shape());
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set unique id to this computation: the computation reuses its root
    // instruction's id, which must have been assigned by now.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or instruction, but we must
    // run the names through the uniquifiers to prevent future name collisions
    // for computations and instructions created later. Also, set the
    // next_unique_id_ to the one greater than the max unique id of any
    // instruction (or the computation) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}
123
// Adds `computation` as the module's entry computation, uniquifying its
// names/ids and installing the default entry layout (layouts on the
// computation itself are NOT preserved).
HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}
130
// Same as AddEntryComputation, but preserves any layouts already present on
// the computation rather than installing the config default.
HloComputation* HloModule::AddEntryComputationWithLayouts(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/true);
}
137
// Removes (and destroys) `to_remove` from this module. Returns an error
// status if the computation is not owned by the module.
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  // NOTE(review): the schedule entry is dropped only when the computation is
  // NOT a called computation -- presumably because called computations keep
  // their schedule until callers are detached; confirm against callers.
  if (has_schedule() && !to_remove->IsCalledComputation()) {
    schedule_->remove_computation(to_remove);
  }

  // Find the owning unique_ptr; erasing it destroys the computation.
  auto it = absl::c_find_if(
      computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return Status::OK();
}
152
// Adds `computation` as a non-entry (embedded) computation, uniquifying its
// names and ids. Returns a non-owning pointer to the added computation.
HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}
159
// For each (old -> new) pair in `replacements`, rewires every instruction in
// the module that calls `old` to call `new` instead, then drops the replaced
// computations from the module (destroying them). The entry computation is
// replaced as well if it appears as a key.
void HloModule::ReplaceComputations(
    const absl::flat_hash_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      // Each opcode stores its called computations in a different slot, so
      // each case rewires the slots that opcode actually has.
      switch (instruction->opcode()) {
        // Opcodes with a single "to_apply" computation.
        case HloOpcode::kAllReduce:
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceScatter:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter:
        case HloOpcode::kSort: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        // While has separate condition and body computations.
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        // Conditional has one computation per branch.
        case HloOpcode::kConditional: {
          for (int b = 0; b < instruction->branch_count(); ++b) {
            HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
                replacements, instruction->branch_computation(b), nullptr);
            if (new_computation != nullptr) {
              instruction->set_branch_computation(b, new_computation);
            }
          }
          break;
        }
        // SelectAndScatter has separate select and scatter computations.
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    // Keep only computations that are not being replaced; replaced ones are
    // destroyed when this scope's moved-from vector is rebuilt below.
    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}
237
ToString(const HloPrintOptions & options) const238 string HloModule::ToString(const HloPrintOptions& options) const {
239 std::ostringstream s;
240 // When print_ids() is false, exclude module's name because it includes and
241 // leads to non-deterministic fingerprint.
242 s << "HloModule "
243 << (options.print_ids() ? PrintName(name(), options.print_ids()) : "");
244 if (has_schedule()) {
245 TF_CHECK_OK(schedule().Verify());
246 s << ", is_scheduled=true";
247 }
248 std::string serialized_aliasing = input_output_alias_config().ToShortString();
249 if (!serialized_aliasing.empty()) {
250 s << absl::StrFormat(", input_output_alias={ %s }", serialized_aliasing);
251 }
252 if (config_.alias_passthrough_params()) {
253 s << ", alias_passthrough_params=true";
254 }
255 s << "\n\n";
256 const auto& computations = options.canonicalize_computations()
257 ? MakeComputationSorted()
258 : MakeComputationPostOrder();
259 for (const HloComputation* computation : computations) {
260 if (!options.print_computation(computation)) {
261 continue;
262 }
263 if (computation == entry_computation()) {
264 s << "ENTRY ";
265 }
266 if (has_schedule() && schedule().is_computation_scheduled(computation)) {
267 s << computation->ToString(
268 options, schedule().sequence(computation).instructions())
269 << "\n\n";
270 } else {
271 s << computation->ToString(options) << "\n\n";
272 }
273 }
274 return s.str();
275 }
276
// Serializes the module to an HloModuleProto: id, name, entry-computation
// identity, every computation (in post order), the schedule (if any), the
// entry program shape, aliasing/dynamic-binding configs, cross-program
// prefetches, and the dynamic flag.
HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  proto.set_entry_computation_id(entry_computation_->unique_id());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    // Swap avoids copying the (potentially large) computation proto.
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  *proto.mutable_host_program_shape() =
      entry_computation_layout().ComputeProgramShape().ToProto();
  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
  *proto.mutable_dynamic_parameter_binding() =
      dynamic_parameter_binding().ToProto();
  // Each cross-program prefetch is a (parameter number, shape index) pair.
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    auto* prefetch = proto.mutable_cross_program_prefetches()->Add();
    prefetch->set_parameter(parameter);
    for (auto index : indices) {
      prefetch->add_index(index);
    }
  }
  proto.set_is_dynamic(is_dynamic_);
  return proto;
}
307
CheckUniqueNamesAndIdsForComputationsAndInstructions() const308 Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
309 absl::flat_hash_set<string> computation_names;
310 absl::flat_hash_set<int> computation_ids;
311 absl::flat_hash_set<string> instruction_names;
312 absl::flat_hash_set<int> instruction_ids;
313
314 for (const HloComputation* computation : computations()) {
315 TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
316 << "Computation name is not unique: " << computation->name();
317 computation_names.insert(computation->name());
318
319 TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
320 << "Computation id is not unique: " << computation->unique_id();
321 computation_ids.insert(computation->unique_id());
322
323 for (const HloInstruction* instruction : computation->instructions()) {
324 TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
325 << "Instruction name is not unique: " << instruction->name();
326 instruction_names.insert(instruction->name());
327
328 TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
329 << "Instruction id is not unique: " << instruction->unique_id();
330 instruction_ids.insert(instruction->unique_id());
331 }
332 }
333 return Status::OK();
334 }
335
/* static */
// Deserializes an HloModule from `proto`.
//  - Validates that the proto's program shape is compatible with the entry
//    layout in `module_config` (parameters and result).
//  - Recreates every computation, preserving the proto's names and ids
//    (identifiers are NOT re-uniquified, then double-checked for uniqueness).
//  - Restores aliasing, dynamic bindings, the schedule, cross-program
//    prefetches, and the dynamic flag.
// `prohibit_empty_literal` is forwarded to computation deserialization.
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    bool prohibit_empty_literal) {
  VLOG(2) << "CreateFromProto()";
  XLA_VLOG_LINES(3, proto.DebugString());

  // The ProgramShape in the passed in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape expected_program_shape(proto.host_program_shape());
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  // Deserialize computations. computation_map lets later computations resolve
  // references to earlier ones; to_proto_id remembers each computation's
  // original proto id for the stable sort below.
  absl::flat_hash_map<int64, HloComputation*> computation_map;
  absl::flat_hash_map<HloComputation*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloComputation>> computations;
  HloComputation* entry = nullptr;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(computation_proto, computation_map,
                                        prohibit_empty_literal));
    CHECK_NE(computation.get(), nullptr);
    int64_t computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    computation_map[computation_id] = computation.get();
    to_proto_id[computation.get()] = computation_id;
    if (computation_id == proto.entry_computation_id()) {
      entry = computation.get();
    }
  }
  TF_RET_CHECK(entry != nullptr);

  auto module = absl::make_unique<HloModule>(proto.name(), module_config);

  // Sort the computations in the proto id's order.
  absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
                                 const std::unique_ptr<HloComputation>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Add sorted computations to the module.
  for (auto& computation : computations) {
    bool is_entry = computation.get() == entry;
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    module->AddComputationInternal(std::move(computation), is_entry,
                                   /*uniquify_identifiers=*/false,
                                   /*preserve_entry_layouts=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config_,
      HloInputOutputAliasConfig::CreateFromProto(
          entry->ComputeProgramShape().result(), proto.input_output_alias()));

  // Because we didn't uniquify the names or the ids, double-check that the
  // instruction and computation names and ids are unique from the proto.
  TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
                      DynamicParameterBinding::CreateFromProto(
                          proto.dynamic_parameter_binding()));

  TF_RETURN_IF_ERROR(
      module->CheckUniqueNamesAndIdsForComputationsAndInstructions());

  if (proto.has_schedule()) {
    TF_ASSIGN_OR_RETURN(
        HloSchedule schedule,
        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  }

  for (auto prefetch : proto.cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  module->set_is_dynamic(proto.is_dynamic());

  return std::move(module);
}
442
/* static */
// Builds an HloModuleConfig for a module with the given entry program shape.
// `execution_options` (optional) supplies replica/partition counts, SPMD and
// dedup flags, and a static device assignment. The entry layout is copied
// from `program_shape` rather than left at the defaults.
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
    const ProgramShape& program_shape, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  HloModuleConfig module_config(ProgramShape{program_shape});
  module_config.set_debug_options(debug_options);
  if (execution_options) {
    // Zero counts mean "unspecified"; keep the config defaults in that case.
    if (execution_options->num_replicas() > 0) {
      module_config.set_replica_count(execution_options->num_replicas());
    }
    if (execution_options->num_partitions() > 0) {
      module_config.set_num_partitions(execution_options->num_partitions());
    }
    module_config.set_use_spmd_partitioning(
        execution_options->use_spmd_partitioning());
    module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
    if (execution_options->has_device_assignment()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                          DeviceAssignment::Deserialize(
                              execution_options->device_assignment()));
      module_config.set_static_device_assignment(*device_assignment);
      // The device assignment must agree with any explicit replica/partition
      // counts supplied above.
      if (execution_options->num_replicas() > 0) {
        CHECK_EQ(module_config.static_device_assignment().replica_count(),
                 module_config.replica_count());
      }
      if (execution_options->num_partitions() > 0) {
        CHECK_EQ(module_config.static_device_assignment().computation_count(),
                 module_config.num_partitions());
      }
    }
  }

  // The module config is constructed with default layouts regardless of what is
  // passed in via the ProgramShape. Set the layouts to the appropriate values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64_t i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}
488
/* static */
// Convenience wrapper: extracts the host program shape from `module` and
// delegates to CreateModuleConfigFromShape. Errors if the proto carries no
// program shape.
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  TF_RET_CHECK(module.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape program_shape(module.host_program_shape());
  return CreateModuleConfigFromShape(program_shape, debug_options,
                                     execution_options);
}
499
500 namespace {
501 // Returns whether `hlo` is used outside the given subcomputation.
502 // `instructions_in_subcomputation` is the instruction set of the given
503 // subcomputation.
IsUsedOutsideSubcomputation(const HloInstruction & hlo,const absl::flat_hash_set<HloInstruction * > & instructions_in_subcomputation)504 bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
505 const absl::flat_hash_set<HloInstruction*>&
506 instructions_in_subcomputation) {
507 return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
508 return !instructions_in_subcomputation.contains(user);
509 });
510 }
511 } // anonymous namespace
512
// Moves `instructions_to_outline` (which must be in topological order and
// form a single-output expression) out of `computation` into a new embedded
// computation named `outlined_computation_name`, replacing them with a single
// kCall instruction. Returns the call. LOG(FATAL)s if the instruction set has
// more than one output (i.e. more than one value used outside the set).
HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  // `arguments` collects values produced outside the outlined set, which
  // become operands of the call; `outputs` collects values used outside it.
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64_t parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands to their counterparts in the new function.
    for (int64_t operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  // Only single-output expressions can be outlined as a plain call.
  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Creates a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  // Reroute all external uses to the call, then delete the originals in
  // reverse topological order (users before producers).
  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}
598
instruction_count() const599 int64 HloModule::instruction_count() const {
600 int64_t n = 0;
601 for (const auto& computation : computations_) {
602 n += computation->instruction_count();
603 }
604 return n;
605 }
606
MakeComputationPostOrder(const absl::flat_hash_set<HloComputation * > & allow_list) const607 std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
608 const absl::flat_hash_set<HloComputation*>& allow_list) const {
609 std::vector<HloComputation*> filtered_post_order(allow_list.size());
610 auto post_order = this->MakeComputationPostOrder();
611
612 int filtered_idx = 0;
613 for (auto& computation : post_order) {
614 if (allow_list.contains(computation)) {
615 filtered_post_order[filtered_idx] = computation;
616 filtered_idx += 1;
617 }
618 }
619
620 return filtered_post_order;
621 }
622
// Returns every computation in the module in post order: each embedded
// computation appears before any computation that calls it. LOG(FATAL)s if
// the traversal does not cover exactly the computations the module owns
// (which would indicate a dangling or foreign computation reference).
std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  absl::flat_hash_set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  absl::flat_hash_set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (!nonroot_computations.contains(computation.get())) {
      // Emit the root's embedded computations first (their own post order),
      // skipping any already emitted via an earlier root.
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (!added_computations.contains(embedded_computation)) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK(!added_computations.contains(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  // Sanity check: the traversal must account for every owned computation.
  if (post_order.size() != computations_.size()) {
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
               << " computation_count=" << computations_.size();
  }
  return post_order;
}
671
672 namespace {
CompareComputationsByContent(const HloComputation * a,const HloComputation * b)673 bool CompareComputationsByContent(const HloComputation* a,
674 const HloComputation* b) {
675 if (a->instruction_count() != b->instruction_count()) {
676 return a->instruction_count() < b->instruction_count();
677 }
678 return a->ToString(HloPrintOptions::Fingerprint()) <
679 b->ToString(HloPrintOptions::Fingerprint());
680 }
681
SortComputationsByContent(std::vector<HloComputation * > * computations)682 void SortComputationsByContent(std::vector<HloComputation*>* computations) {
683 absl::c_sort(*computations, CompareComputationsByContent);
684 }
685
686 } // anonymous namespace
687
MakeComputationSorted() const688 std::vector<HloComputation*> HloModule::MakeComputationSorted() const {
689 std::vector<HloComputation*> result = MakeComputationPostOrder();
690 if (config().content_aware_computation_sorting()) {
691 SortComputationsByContent(&result);
692 }
693 return result;
694 }
695
MakeNonfusionComputations() const696 std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
697 std::vector<HloComputation*> result = MakeComputationPostOrder();
698 result.erase(std::remove_if(
699 result.begin(), result.end(),
700 [](HloComputation* c) { return c->IsFusionComputation(); }),
701 result.end());
702 return result;
703 }
704
MakeNonfusionComputationsSorted() const705 std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted()
706 const {
707 auto result = MakeNonfusionComputations();
708 if (config().content_aware_computation_sorting()) {
709 SortComputationsByContent(&result);
710 }
711 return result;
712 }
713
// Clones the module (reachable computations only) using this module's own
// config; see the two-argument overload for details.
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  return Clone(config(), suffix);
}
717
// Deep-clones the module under `config`. The clone's name is
// "<name>-<suffix>" (or just "<name>" when suffix is empty). Only
// computations reachable from the entry computation are cloned; the schedule
// (if valid), aliasing config, dynamic flag, and cross-program prefetches are
// carried over.
std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
                                            const string& suffix) const {
  VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config);

  // Cloning the entry transitively clones everything reachable from it.
  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));
  module->input_output_alias_config() = input_output_alias_config();
  module->set_is_dynamic(is_dynamic());
  // Only copy the schedule if it currently verifies; an invalid schedule is
  // silently dropped from the clone.
  if (has_schedule() && schedule().Verify().ok()) {
    HloSchedule clone_schedule(module.get());
    for (HloComputation* computation : computations()) {
      if (schedule().is_computation_scheduled(computation)) {
        HloComputation* new_computation = context.FindComputation(computation);
        // The module being cloned may have computations that are dead, i.e.,
        // unreachable from the entry computation. In that case, new_computation
        // is nullptr.
        if (new_computation != nullptr) {
          HloInstructionSequence& clone_sequence =
              clone_schedule.GetOrCreateSequence(new_computation);
          for (const HloInstruction* instruction :
               schedule().sequence(computation).instructions()) {
            clone_sequence.push_back(context.GetInstruction(instruction));
          }
        }
      }
    }
    TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
  }
  // Cross-program prefetches refer to parameters by number, so they transfer
  // directly to the clone.
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    module->AddCrossProgramPrefetch(parameter, indices);
  }
  return module;
}
756
RemoveUnusedComputations()757 Status HloModule::RemoveUnusedComputations() {
758 std::string suffix = "tmp";
759 auto module = absl::make_unique<HloModule>(
760 absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config());
761 HloCloneContext context(module.get(), suffix);
762 entry_computation_->Clone(suffix, &context);
763 std::vector<HloComputation*> to_remove;
764 for (auto computation : computations()) {
765 auto found_computation = context.FindComputation(computation);
766 if (found_computation == nullptr) {
767 to_remove.push_back(computation);
768 }
769 }
770 for (auto computation : to_remove) {
771 TF_RETURN_IF_ERROR(RemoveEmbeddedComputation(computation));
772 }
773 return Status::OK();
774 }
775
DeepCloneComputation(HloComputation * computation,HloCloneContext * context)776 HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
777 HloCloneContext* context) {
778 HloComputation* new_computation;
779 if (context != nullptr) {
780 if ((new_computation = context->FindComputation(computation)) != nullptr) {
781 return new_computation;
782 }
783 new_computation =
784 AddEmbeddedComputation(computation->Clone(context->suffix(), context));
785 } else {
786 new_computation = AddEmbeddedComputation(computation->Clone(""));
787 }
788 return new_computation;
789 }
790
// Draws a value from the module's random generator. The mutex serializes
// access so concurrent callers each get an independent draw.
uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}
795
GetComputationWithName(absl::string_view name)796 HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
797 auto computations_in_module = computations();
798 auto it = absl::c_find_if(
799 computations_in_module,
800 [&](HloComputation* computation) { return computation->name() == name; });
801 return it == computations_in_module.end() ? nullptr : *it;
802 }
803
Hash() const804 uint64 HloModule::Hash() const {
805 uint64 result = entry_computation_layout().Hash();
806 // Use MakeComputationSorted() instead of MakeComputationPostOrder()
807 // because naming may affect the order of MakeComputationPostOrder() but not
808 // MakeComputationSorted().
809 for (auto* computation : MakeComputationSorted()) {
810 for (auto* instruction : computation->MakeInstructionPostOrder()) {
811 result = tensorflow::Hash64Combine(result, instruction->Hash());
812 }
813 }
814 return result;
815 }
816
// Process-wide counter backing each module's unique_id_ (see constructor).
/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);
818
819 } // namespace xla
820