1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17
18 #include <algorithm>
19 #include <iterator>
20 #include <set>
21 #include <sstream>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_cat.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/gtl/map_util.h"
38 #include "tensorflow/core/lib/hash/hash.h"
39 #include "tensorflow/core/platform/stacktrace.h"
40 #include "tensorflow/core/platform/types.h"
41
42 namespace xla {
43
// Constructs a module named with a sanitized version of `name`, taking
// ownership of `config`. Every module receives a process-wide unique id
// (drawn from next_unique_module_id_), which is also recorded as the
// canonical module id in the metadata.
HloModule::HloModule(const string& name, HloModuleConfig config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(std::move(config)),
      unique_id_(next_unique_module_id_++),
      metadata_(tensorflow::Env::Default()) {
  metadata_.set_canonical_module_id(unique_id_);
}
51
// Installs `schedule` as this module's instruction schedule. The schedule
// must belong to this module and must pass verification; otherwise an error
// status is returned and the existing schedule is left untouched.
Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return Status::OK();
}
58
// Points the module at a new entry computation. Ownership is not transferred
// here — NOTE(review): `entry_computation` is presumably already owned by
// this module; confirm at call sites. The default entry layout and the
// input/output alias config are reset to match the new entry's program shape
// and root shape.
void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) {
  entry_computation_ = entry_computation;
  config_.SetDefaultComputationLayout(
      entry_computation_->ComputeProgramShape());
  input_output_alias_config_ = HloInputOutputAliasConfig(
      entry_computation_->root_instruction()->shape());
}
66
// Takes ownership of `computation`, appends it to the module, and returns a
// raw pointer to it.
//
// If `is_entry` is true, the computation becomes the module entry (there must
// not already be one). The entry layout is then either copied from the
// computed program shape when `preserve_entry_layouts` is set, or defaulted
// from the program shape when the config has no entry layout yet. The
// input/output alias config is reset to the new root's shape.
//
// If `uniquify_identifiers` is true, the computation's and instructions'
// names and ids are rewritten to be unique within this module. Otherwise the
// existing names/ids are kept as-is, but they are still registered with the
// uniquifiers (and next_unique_id_ is bumped) so that identifiers created
// later cannot collide with them.
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers, bool preserve_entry_layouts) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    if (preserve_entry_layouts) {
      config_.SetComputationLayoutIfExists(
          entry_computation_->ComputeProgramShape());
    } else if (!config_.has_entry_computation_layout()) {
      // If the module configuration has no entry layout computation set, create
      // a default one based on the program shape.
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
    input_output_alias_config_ = HloInputOutputAliasConfig(
        entry_computation_->root_instruction()->shape());
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set unique id to this computation.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or instruction, but we must
    // run the names through the uniquifiers to prevent future name collisions
    // for computations and instructions created later. Also, set the
    // next_unique_id_ to the one greater than the max unique id of any
    // instruction (or the computation) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}
121
AddEntryComputation(std::unique_ptr<HloComputation> computation)122 HloComputation* HloModule::AddEntryComputation(
123 std::unique_ptr<HloComputation> computation) {
124 return AddComputationInternal(std::move(computation), /*is_entry=*/true,
125 /*uniquify_identifiers=*/true,
126 /*preserve_entry_layouts=*/false);
127 }
128
AddEntryComputationWithLayouts(std::unique_ptr<HloComputation> computation)129 HloComputation* HloModule::AddEntryComputationWithLayouts(
130 std::unique_ptr<HloComputation> computation) {
131 return AddComputationInternal(std::move(computation), /*is_entry=*/true,
132 /*uniquify_identifiers=*/true,
133 /*preserve_entry_layouts=*/true);
134 }
135
// Removes (and destroys) `to_remove` from the module. Returns an error status
// if the computation is not owned by this module. When the module is
// scheduled, the computation's sequence is dropped from the schedule first;
// fusion computations are skipped there.
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  if (has_schedule() && !to_remove->IsFusionComputation()) {
    schedule_->remove_computation(to_remove);
  }

  auto it = absl::c_find_if(
      computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return Status::OK();
}
150
AddEmbeddedComputation(std::unique_ptr<HloComputation> computation)151 HloComputation* HloModule::AddEmbeddedComputation(
152 std::unique_ptr<HloComputation> computation) {
153 return AddComputationInternal(std::move(computation), /*is_entry=*/false,
154 /*uniquify_identifiers=*/true,
155 /*preserve_entry_layouts=*/false);
156 }
157
// Rewrites every reference to a key computation in `replacements` to point at
// the corresponding value computation, then drops the replaced computations
// from the module. Handles all instruction kinds that carry called
// computations (to_apply-style ops, while, conditional, select-and-scatter)
// and substitutes the entry computation when it is itself replaced.
// NOTE(review): the replacement (value) computations are presumably already
// owned by this module — confirm at call sites, since nothing here adds them.
void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        // Ops whose single called computation is accessed via to_apply().
        case HloOpcode::kAllReduce:
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter:
        case HloOpcode::kSort: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        // While carries two computations: condition and body.
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        // Conditional carries one computation per branch.
        case HloOpcode::kConditional: {
          for (int b = 0; b < instruction->branch_count(); ++b) {
            HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
                replacements, instruction->branch_computation(b), nullptr);
            if (new_computation != nullptr) {
              instruction->set_branch_computation(b, new_computation);
            }
          }
          break;
        }
        // SelectAndScatter carries a select and a scatter computation.
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    // Keep only computations that are not being replaced; replaced ones are
    // destroyed when new_computations takes over below.
    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}
234
// Renders the module as HLO text. The header line carries the module name
// (omitted when print_ids() is false for fingerprint stability), an
// is_scheduled flag, and any input/output aliasing. Computations follow in
// post order (or content-sorted order when canonicalize_computations() is
// set), with the entry computation prefixed by "ENTRY"; scheduled
// computations are printed in their scheduled instruction order.
string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  // When print_ids() is false, exclude module's name because it includes and
  // leads to non-deterministic fingerprint.
  s << "HloModule "
    << (options.print_ids() ? PrintName(name(), options.print_ids()) : "");
  if (has_schedule()) {
    TF_CHECK_OK(schedule().Verify());
    s << ", is_scheduled=true";
  }
  std::string serialized_aliasing = input_output_alias_config().ToShortString();
  if (!serialized_aliasing.empty()) {
    s << absl::StrFormat(", input_output_alias={ %s }", serialized_aliasing);
  }
  s << "\n\n";
  const auto& computations = options.canonicalize_computations()
                                 ? MakeComputationSorted()
                                 : MakeComputationPostOrder();
  for (const HloComputation* computation : computations) {
    // Honor any per-computation filter supplied via the print options.
    if (!options.print_computation(computation)) {
      continue;
    }
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
      s << computation->ToString(
               options, schedule().sequence(computation).instructions())
        << "\n\n";
    } else {
      s << computation->ToString(options) << "\n\n";
    }
  }
  return s.str();
}
270
// Serializes the module to an HloModuleProto: ids/names, every computation in
// post order, the schedule (if any), the host program shape, aliasing,
// dynamic parameter bindings, and cross-program prefetch annotations.
HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  proto.set_entry_computation_id(entry_computation_->unique_id());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    // Swap avoids copying the serialized computation into the repeated field.
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  *proto.mutable_host_program_shape() =
      entry_computation_layout().ComputeProgramShape().ToProto();
  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
  *proto.mutable_dynamic_parameter_binding() =
      dynamic_parameter_binding().ToProto();
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    auto* prefetch = proto.mutable_cross_program_prefetches()->Add();
    prefetch->set_parameter(parameter);
    for (auto index : indices) {
      prefetch->add_index(index);
    }
  }
  return proto;
}
300
CheckUniqueNamesAndIdsForComputationsAndInstructions() const301 Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
302 absl::flat_hash_set<string> computation_names;
303 absl::flat_hash_set<int> computation_ids;
304 absl::flat_hash_set<string> instruction_names;
305 absl::flat_hash_set<int> instruction_ids;
306
307 for (const HloComputation* computation : computations()) {
308 TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
309 << "Computation name is not unique: " << computation->name();
310 computation_names.insert(computation->name());
311
312 TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
313 << "Computation id is not unique: " << computation->unique_id();
314 computation_ids.insert(computation->unique_id());
315
316 for (const HloInstruction* instruction : computation->instructions()) {
317 TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
318 << "Instruction name is not unique: " << instruction->name();
319 instruction_names.insert(instruction->name());
320
321 TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
322 << "Instruction id is not unique: " << instruction->unique_id();
323 instruction_ids.insert(instruction->unique_id());
324 }
325 }
326 return Status::OK();
327 }
328
329 /* static */
CreateFromProto(const HloModuleProto & proto,const HloModuleConfig & module_config,bool prohibit_empty_literal)330 StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
331 const HloModuleProto& proto, const HloModuleConfig& module_config,
332 bool prohibit_empty_literal) {
333 VLOG(2) << "CreateFromProto()";
334 XLA_VLOG_LINES(3, proto.DebugString());
335
336 // The ProgramShape in the passed in module config must match the shapes of
337 // the entry parameters and root.
338 TF_RET_CHECK(proto.has_host_program_shape())
339 << "No program shape found in the proto";
340 ProgramShape expected_program_shape(proto.host_program_shape());
341 TF_RET_CHECK(expected_program_shape.parameters_size() ==
342 module_config.entry_computation_layout().parameter_count());
343 for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
344 const Shape& parameter_shape =
345 module_config.entry_computation_layout().parameter_layout(i).shape();
346 TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
347 parameter_shape))
348 << "HloModuleConfig has different shape for parameter " << i
349 << " than the HLO module. Expected: "
350 << ShapeUtil::HumanStringWithLayout(
351 expected_program_shape.parameters(i))
352 << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
353 }
354 const Shape& result_shape =
355 module_config.entry_computation_layout().result_layout().shape();
356 TF_RET_CHECK(
357 ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
358 << "HloModuleConfig has different result shape than the HLO module. "
359 "Expected: "
360 << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
361 << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
362
363 absl::flat_hash_map<int64, HloComputation*> computation_map;
364 absl::flat_hash_map<HloComputation*, int64> to_proto_id;
365 std::vector<std::unique_ptr<HloComputation>> computations;
366 HloComputation* entry = nullptr;
367 for (const HloComputationProto& computation_proto : proto.computations()) {
368 TF_ASSIGN_OR_RETURN(
369 std::unique_ptr<HloComputation> computation,
370 HloComputation::CreateFromProto(computation_proto, computation_map,
371 prohibit_empty_literal));
372 CHECK_NE(computation.get(), nullptr);
373 int64 computation_id = computation_proto.id();
374 TF_RET_CHECK(computation_id != -1);
375 TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
376 computation_map[computation_id] = computation.get();
377 to_proto_id[computation.get()] = computation_id;
378 if (computation_id == proto.entry_computation_id()) {
379 entry = computation.get();
380 }
381 computations.push_back(std::move(computation));
382 }
383 TF_RET_CHECK(entry != nullptr);
384
385 auto module = absl::make_unique<HloModule>(proto.name(), module_config);
386
387 // Sort the computations in the proto id's order.
388 absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
389 const std::unique_ptr<HloComputation>& b) {
390 return to_proto_id[a.get()] < to_proto_id[b.get()];
391 });
392
393 // Add sorted computations to the module.
394 for (auto& computation : computations) {
395 bool is_entry = computation.get() == entry;
396 // Don't uniquify names because we want names to be stable across
397 // serialization and deserialization.
398 module->AddComputationInternal(std::move(computation), is_entry,
399 /*uniquify_identifiers=*/false,
400 /*preserve_entry_layouts=*/false);
401 }
402 TF_RET_CHECK(module->entry_computation_ != nullptr);
403
404 TF_ASSIGN_OR_RETURN(
405 module->input_output_alias_config_,
406 HloInputOutputAliasConfig::CreateFromProto(
407 entry->ComputeProgramShape().result(), proto.input_output_alias()));
408
409 // Because we didn't uniquify the names or the ids, double-check that the
410 // instruction and computation names and ids are unique from the proto.
411 TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
412 DynamicParameterBinding::CreateFromProto(
413 proto.dynamic_parameter_binding()));
414
415 TF_RETURN_IF_ERROR(
416 module->CheckUniqueNamesAndIdsForComputationsAndInstructions());
417
418 if (proto.has_schedule()) {
419 TF_ASSIGN_OR_RETURN(
420 HloSchedule schedule,
421 HloSchedule::CreateFromProto(module.get(), proto.schedule()));
422 TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
423 }
424
425 for (auto prefetch : proto.cross_program_prefetches()) {
426 module->AddCrossProgramPrefetch(
427 prefetch.parameter(),
428 ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
429 }
430
431 return std::move(module);
432 }
433
434 /* static */
// Builds an HloModuleConfig for `program_shape`, applying `debug_options` and
// (optionally) replica/partition counts, SPMD/deduplication flags, and a
// static device assignment from `execution_options`. The entry computation
// layout is then copied from `program_shape` rather than left at defaults.
// Returns an error status if layout copying or device-assignment
// deserialization fails.
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
    const ProgramShape& program_shape, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  HloModuleConfig module_config(ProgramShape{program_shape});
  module_config.set_debug_options(debug_options);
  if (execution_options) {
    if (execution_options->num_replicas() > 0) {
      module_config.set_replica_count(execution_options->num_replicas());
    }
    if (execution_options->num_partitions() > 0) {
      module_config.set_num_partitions(execution_options->num_partitions());
    }
    module_config.set_use_spmd_partitioning(
        execution_options->use_spmd_partitioning());
    module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
    module_config.set_broadcast_replicated_params(
        execution_options->broadcast_replicated_parameters_via_collectives());
    if (execution_options->has_device_assignment()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                          DeviceAssignment::Deserialize(
                              execution_options->device_assignment()));
      module_config.set_static_device_assignment(*device_assignment);
      // When counts were explicitly given above, they must agree with the
      // deserialized device assignment.
      if (execution_options->num_replicas() > 0) {
        CHECK_EQ(module_config.static_device_assignment().replica_count(),
                 module_config.replica_count());
      }
      if (execution_options->num_partitions() > 0) {
        CHECK_EQ(module_config.static_device_assignment().computation_count(),
                 module_config.num_partitions());
      }
    }
  }

  // The module config is constructed with default layouts regardless of what is
  // passed in via the ProgramShape. Set the layouts to the appropriate values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}
481
482 /* static */
CreateModuleConfigFromProto(const HloModuleProto & module,const DebugOptions & debug_options,const ExecutionOptions * execution_options)483 StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
484 const HloModuleProto& module, const DebugOptions& debug_options,
485 const ExecutionOptions* execution_options) {
486 TF_RET_CHECK(module.has_host_program_shape())
487 << "No program shape found in the proto";
488 ProgramShape program_shape(module.host_program_shape());
489 return CreateModuleConfigFromShape(program_shape, debug_options,
490 execution_options);
491 }
492
493 namespace {
494 // Returns whether `hlo` is used outside the given subcomputation.
495 // `instructions_in_subcomputation` is the instruction set of the given
496 // subcomputation.
IsUsedOutsideSubcomputation(const HloInstruction & hlo,const absl::flat_hash_set<HloInstruction * > & instructions_in_subcomputation)497 bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
498 const absl::flat_hash_set<HloInstruction*>&
499 instructions_in_subcomputation) {
500 return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
501 return !instructions_in_subcomputation.contains(user);
502 });
503 }
504 } // anonymous namespace
505
// Moves `instructions_to_outline` (which must be in topological order) out of
// `computation` into a newly built embedded computation named
// `outlined_computation_name`, replacing them with a single kCall to it.
// Operands produced outside the outlined set become parameters of the new
// computation. The outlined expression must have exactly one output; a
// multi-output expression is a fatal error. Returns the call instruction.
HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands to their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      // operator[] default-inserts nullptr for a first-time operand.
      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Creates a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  // Remove in reverse (topological) order so users are removed before their
  // operands.
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}
591
instruction_count() const592 int64 HloModule::instruction_count() const {
593 int64 n = 0;
594 for (const auto& computation : computations_) {
595 n += computation->instruction_count();
596 }
597 return n;
598 }
599
MakeComputationPostOrder(const absl::flat_hash_set<HloComputation * > & allow_list) const600 std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
601 const absl::flat_hash_set<HloComputation*>& allow_list) const {
602 std::vector<HloComputation*> filtered_post_order(allow_list.size());
603 auto post_order = this->MakeComputationPostOrder();
604
605 int filtered_idx = 0;
606 for (auto& computation : post_order) {
607 if (allow_list.contains(computation)) {
608 filtered_post_order[filtered_idx] = computation;
609 filtered_idx += 1;
610 }
611 }
612
613 return filtered_post_order;
614 }
615
// Returns every computation in the module in post order: for each root
// computation (one not called by any instruction in the module), its embedded
// computations appear before it, and no computation appears twice. A mismatch
// between the post order and the module's computation count is a fatal error
// (it would indicate a computation unreachable from any root or a sharing
// inconsistency).
std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  absl::flat_hash_set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  absl::flat_hash_set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (!nonroot_computations.contains(computation.get())) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (!added_computations.contains(embedded_computation)) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK(!added_computations.contains(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  if (post_order.size() != computations_.size()) {
    // Dump both lists before dying to make the mismatch diagnosable.
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
               << " computation_count=" << computations_.size();
  }
  return post_order;
}
664
665 namespace {
CompareComputationsByContent(HloComputation * a,HloComputation * b)666 bool CompareComputationsByContent(HloComputation* a, HloComputation* b) {
667 if (a->instruction_count() != b->instruction_count()) {
668 return a->instruction_count() < b->instruction_count();
669 }
670 return a->ToString(HloPrintOptions::Fingerprint()) <
671 b->ToString(HloPrintOptions::Fingerprint());
672 }
673 } // anonymous namespace
674
MakeComputationSorted() const675 std::vector<HloComputation*> HloModule::MakeComputationSorted() const {
676 std::vector<HloComputation*> result = MakeComputationPostOrder();
677 if (config().content_aware_computation_sorting()) {
678 absl::c_sort(result, CompareComputationsByContent);
679 }
680 return result;
681 }
682
MakeNonfusionComputations() const683 std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
684 std::vector<HloComputation*> result = MakeComputationPostOrder();
685 result.erase(std::remove_if(
686 result.begin(), result.end(),
687 [](HloComputation* c) { return c->IsFusionComputation(); }),
688 result.end());
689 return result;
690 }
691
MakeNonfusionComputationsSorted() const692 std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted()
693 const {
694 auto result = MakeNonfusionComputations();
695 if (config().content_aware_computation_sorting()) {
696 absl::c_sort(result, CompareComputationsByContent);
697 }
698 return result;
699 }
700
// Clones this module under its current config, appending `suffix` to names.
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  return Clone(config(), suffix);
}
704
// Deep-clones this module under `config`. The new module's name is
// "<name>-<suffix>" (or just <name> when suffix is empty). Only computations
// reachable from the entry are cloned (via the clone context); the schedule —
// if present and still valid — and cross-program prefetches are carried over,
// remapped through the clone context.
std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
                                            const string& suffix) const {
  VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config);

  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));
  module->input_output_alias_config() = input_output_alias_config();

  if (has_schedule() && schedule().Verify().ok()) {
    // Rebuild the schedule in terms of the cloned computations/instructions.
    HloSchedule clone_schedule(module.get());
    for (HloComputation* computation : computations()) {
      if (schedule().is_computation_scheduled(computation)) {
        HloInstructionSequence& clone_sequence =
            clone_schedule.GetOrCreateSequence(
                context.GetComputation(computation));
        for (const HloInstruction* instruction :
             schedule().sequence(computation).instructions()) {
          clone_sequence.push_back(context.GetInstruction(instruction));
        }
      }
    }
    TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
  }
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    module->AddCrossProgramPrefetch(parameter, indices);
  }
  return module;
}
738
// Removes every computation that is not reachable from the entry computation.
// Reachability is discovered by cloning the entry into a throwaway module and
// checking which of this module's computations the clone context visited;
// unvisited computations are then removed.
Status HloModule::RemoveUnusedComputations() {
  std::string suffix = "tmp";
  // Scratch module used purely to drive the reachability-discovering clone.
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config());
  HloCloneContext context(module.get(), suffix);
  entry_computation_->Clone(suffix, &context);
  std::vector<HloComputation*> to_remove;
  for (auto computation : computations()) {
    auto found_computation = context.FindComputation(computation);
    if (found_computation == nullptr) {
      to_remove.push_back(computation);
    }
  }
  for (auto computation : to_remove) {
    TF_RETURN_IF_ERROR(RemoveEmbeddedComputation(computation));
  }
  return Status::OK();
}
757
DeepCloneComputation(HloComputation * computation,HloCloneContext * context)758 HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
759 HloCloneContext* context) {
760 HloComputation* new_computation;
761 if (context != nullptr) {
762 if ((new_computation = context->FindComputation(computation)) != nullptr) {
763 return new_computation;
764 }
765 new_computation =
766 AddEmbeddedComputation(computation->Clone(context->suffix(), context));
767 } else {
768 new_computation = AddEmbeddedComputation(computation->Clone(""));
769 }
770 return new_computation;
771 }
772
// Returns the next value from the module's random number generator. The
// rng_mutex_ lock makes concurrent calls safe.
uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}
777
GetComputationWithName(absl::string_view name)778 HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
779 auto computations_in_module = computations();
780 auto it = absl::c_find_if(
781 computations_in_module,
782 [&](HloComputation* computation) { return computation->name() == name; });
783 return it == computations_in_module.end() ? nullptr : *it;
784 }
785
// Returns a fingerprint of the module: the entry computation layout's hash
// combined with the hash of every instruction, in a naming-insensitive order.
uint64 HloModule::Hash() const {
  uint64 result = entry_computation_layout().Hash();
  // Use MakeComputationSorted() instead of MakeComputationPostOrder()
  // because naming may affect the order of MakeComputationPostOrder() but not
  // MakeComputationSorted().
  for (auto* computation : MakeComputationSorted()) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      result = tensorflow::Hash64Combine(result, instruction->Hash());
    }
  }
  return result;
}
798
799 /* static */ std::atomic<int> HloModule::next_unique_module_id_(0);
800
801 } // namespace xla
802