/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/compilation_environments.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"

namespace xla {
HloModule::HloModule(const std::string& name, HloModuleConfig config)
    : HloModule(name, config, std::make_unique<CompilationEnvironments>()) {}

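// Installs a schedule for this module after verifying it against the module's
// computations. A minimal usage sketch (assuming `module` already owns its
// computations and the sequences have been populated):
//
//   HloSchedule schedule(module.get());
//   // ... fill in schedule.GetOrCreateSequence(...) per computation ...
//   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));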
Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return OkStatus();
}

void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) {
  entry_computation_ = entry_computation;
  config_.SetDefaultComputationLayout(
      entry_computation_->ComputeProgramShape());
  input_output_alias_config_ = HloInputOutputAliasConfig(
      entry_computation_->root_instruction()->shape());
}

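// AddComputationInternal is the single path through which computations join
// the module. Callers choose whether identifiers are uniquified: the builder
// paths below uniquify, while proto deserialization (CreateFromProto) keeps
// names and ids stable and only advances the uniquifiers and next_unique_id_
// so that later additions cannot collide.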
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers, bool preserve_entry_layouts) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    if (preserve_entry_layouts) {
      config_.SetComputationLayoutIfExists(
          entry_computation_->ComputeProgramShape());
    } else if (!config_.has_entry_computation_layout()) {
      // If the module configuration has no entry computation layout set,
      // create a default one based on the program shape.
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
    input_output_alias_config_ = HloInputOutputAliasConfig(
        entry_computation_->root_instruction()->shape());
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set a unique id for this computation.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or instructions, but we
    // must run the names through the uniquifiers to prevent future name
    // collisions for computations and instructions created later. Also, set
    // next_unique_id_ to one greater than the max unique id of any
    // instruction (or the computation) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ =
          std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}

HloComputation* HloModule::AddEntryComputationWithLayouts(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/true);
}

Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  if (has_schedule() && !to_remove->IsCalledComputation()) {
    schedule_->remove_computation(to_remove);
  }

  auto it = absl::c_find_if(
      computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return OkStatus();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}

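// Rewrites every instruction whose called computations appear as keys in
// `replacements`, then drops the replaced computations from the module. A
// usage sketch (hypothetical pointers; e.g. after finding that `duplicate`
// and `canonical` are structurally identical):
//
//   absl::flat_hash_map<HloComputation*, HloComputation*> replacements;
//   replacements[duplicate] = canonical;
//   module->ReplaceComputations(replacements);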
void HloModule::ReplaceComputations(
    const absl::flat_hash_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceScatter:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter:
        case HloOpcode::kSort: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          for (int b = 0; b < instruction->branch_count(); ++b) {
            HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
                replacements, instruction->branch_computation(b), nullptr);
            if (new_computation != nullptr) {
              instruction->set_branch_computation(b, new_computation);
            }
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

std::string HloModule::ToString(const HloPrintOptions& options) const {
  return std::string(ToCord(options));
}

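// ToCord renders the module header (name, schedule flag, aliasing, layouts,
// SPMD flags) followed by each computation, prefixing the entry computation
// with "ENTRY" and using the scheduled instruction order when one exists.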
absl::Cord HloModule::ToCord(const HloPrintOptions& options) const {
  absl::Cord result;
  result.Append("HloModule ");
  if (options.print_ids()) {
    // When print_ids() is false, exclude the module's name because it
    // includes an id that leads to a non-deterministic fingerprint.
    result.Append(name());
  }
  if (has_schedule()) {
    TF_CHECK_OK(schedule().Verify());
    result.Append(", is_scheduled=true");
  }
  std::string serialized_aliasing = input_output_alias_config().ToShortString();
  if (!serialized_aliasing.empty()) {
    result.Append(", input_output_alias={ ");
    result.Append(std::move(serialized_aliasing));
    result.Append(" }");
  }
  if (config_.alias_passthrough_params()) {
    result.Append(", alias_passthrough_params=true");
  }
  if (config_.has_entry_computation_layout()) {
    result.Append(", entry_computation_layout={");
    result.Append(entry_computation_layout().ToString());
    result.Append("}");
  }
  if (config_.allow_spmd_sharding_propagation_to_output()) {
    result.Append(", allow_spmd_sharding_propagation_to_output=true");
  }
  result.Append("\n\n");
  const auto& computations = options.canonicalize_computations()
                                 ? MakeComputationSorted()
                                 : MakeComputationPostOrder();
  for (const HloComputation* computation : computations) {
    // Don't print async computations when the syntax sugar is enabled since
    // that is redundant information.
    if (options.syntax_sugar_async_ops() && computation->IsAsyncComputation()) {
      continue;
    }
    if (computation == entry_computation()) {
      result.Append("ENTRY ");
    }
    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
      result.Append(computation->ToCord(
          options, schedule().sequence(computation).instructions()));
    } else {
      result.Append(computation->ToCord(options));
    }
    result.Append("\n\n");
  }
  return result;
}

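// Serializes the module. Computations are emitted in post order so that
// callees precede their callers, which lets CreateFromProto resolve called
// computation ids in a single pass.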
HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  if (entry_computation_) {
    proto.set_entry_computation_name(entry_computation_->name());
    proto.set_entry_computation_id(entry_computation_->unique_id());
    *proto.mutable_host_program_shape() =
        entry_computation_layout().ComputeProgramShape().ToProto();
  }
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
  *proto.mutable_dynamic_parameter_binding() =
      dynamic_parameter_binding().ToProto();
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    auto* prefetch = proto.mutable_cross_program_prefetches()->Add();
    prefetch->set_parameter(parameter);
    for (auto index : indices) {
      prefetch->add_index(index);
    }
  }
  proto.set_is_dynamic(is_dynamic_);
  if (has_spmd_output_sharding()) {
    *proto.mutable_spmd_output_sharding() = spmd_output_sharding().ToProto();
  }

  if (has_spmd_parameters_shardings()) {
    for (const auto& parameter_sharding : spmd_parameters_shardings()) {
      *proto.add_spmd_parameters_shardings() = parameter_sharding.ToProto();
    }
  }

  proto.set_use_auto_spmd_partitioning(use_auto_spmd_partitioning_);

  for (const HloModuleProto::ProfileInfo& profile_info : profile_info_list_) {
    HloModuleProto::ProfileInfo& profile_info_proto =
        *proto.mutable_profile_info()->Add();
    profile_info_proto.set_profile_type(profile_info.profile_type());
    profile_info_proto.set_relative_speedup(profile_info.relative_speedup());
    profile_info_proto.set_profile_source(profile_info.profile_source());
    profile_info_proto.set_compilation_event(profile_info.compilation_event());
  }
  if (this->config_.has_static_device_assignment()) {
    DeviceAssignmentProto device_assignment;
    TF_CHECK_OK(
        this->config_.static_device_assignment().Serialize(&device_assignment));
    (*proto.mutable_device_assignment()) = device_assignment;
  }
  return proto;
}

Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
  absl::flat_hash_set<std::string> computation_names;
  absl::flat_hash_set<int> computation_ids;
  absl::flat_hash_set<std::string> instruction_names;
  absl::flat_hash_set<int> instruction_ids;

  for (const HloComputation* computation : computations()) {
    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());

    TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
        << "Computation id is not unique: " << computation->unique_id();
    computation_ids.insert(computation->unique_id());

    for (const HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());

      TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
          << "Instruction id is not unique: " << instruction->unique_id();
      instruction_ids.insert(instruction->unique_id());
    }
  }
  return OkStatus();
}

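// CreateFromProto is the inverse of ToProto. A round-trip sketch (assuming a
// proto produced by ToProto and a config built for the same program shape):
//
//   HloModuleProto proto = module->ToProto();
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloModule> restored,
//       HloModule::CreateFromProto(proto, module->config(),
//                                  /*prohibit_empty_literal=*/true));
//
// Names and ids survive the round trip because deserialization does not
// re-uniquify identifiers.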
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    bool prohibit_empty_literal) {
  VLOG(2) << "CreateFromProto()";
  XLA_VLOG_LINES(3, proto.DebugString());

  // The ProgramShape in the passed in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape expected_program_shape(proto.host_program_shape());
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  absl::flat_hash_map<int64_t, HloComputation*> computation_map;
  absl::flat_hash_map<HloComputation*, int64_t> to_proto_id;
  std::vector<std::unique_ptr<HloComputation>> computations;
  HloComputation* entry = nullptr;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(computation_proto, computation_map,
                                        prohibit_empty_literal));
    CHECK_NE(computation.get(), nullptr);
    int64_t computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    computation_map[computation_id] = computation.get();
    to_proto_id[computation.get()] = computation_id;
    if (computation_id == proto.entry_computation_id()) {
      entry = computation.get();
    }
    computations.push_back(std::move(computation));
  }
  TF_RET_CHECK(entry != nullptr);

  auto module = std::make_unique<HloModule>(proto.name(), module_config);

  // Sort the computations in proto id order.
  absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
                                 const std::unique_ptr<HloComputation>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Add sorted computations to the module.
  for (auto& computation : computations) {
    bool is_entry = computation.get() == entry;
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    module->AddComputationInternal(std::move(computation), is_entry,
                                   /*uniquify_identifiers=*/false,
                                   /*preserve_entry_layouts=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);
  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config_,
      HloInputOutputAliasConfig::CreateFromProto(
          entry->ComputeProgramShape().result(), proto.input_output_alias()));

  TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
                      DynamicParameterBinding::CreateFromProto(
                          proto.dynamic_parameter_binding()));

  // Because we didn't uniquify the names or the ids, double-check that the
  // instruction and computation names and ids are unique from the proto.
  TF_RETURN_IF_ERROR(
      module->CheckUniqueNamesAndIdsForComputationsAndInstructions());

  if (proto.has_schedule()) {
    TF_ASSIGN_OR_RETURN(
        HloSchedule schedule,
        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  }

  for (const auto& prefetch : proto.cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  module->set_is_dynamic(proto.is_dynamic());

  if (proto.has_spmd_output_sharding()) {
    TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
                        HloSharding::FromProto(proto.spmd_output_sharding()));
    module->set_spmd_output_sharding(hlo_sharding);
  }

  std::vector<HloSharding> param_shardings;
  for (const auto& sharding_proto : proto.spmd_parameters_shardings()) {
    TF_ASSIGN_OR_RETURN(HloSharding sharding,
                        HloSharding::FromProto(sharding_proto));
    param_shardings.push_back(sharding);
  }
  if (!param_shardings.empty()) {
    module->set_spmd_parameters_shardings(param_shardings);
  }

  module->set_use_auto_spmd_partitioning(proto.use_auto_spmd_partitioning());

  for (const auto& profile_info : proto.profile_info()) {
    module->add_profile_info(profile_info);
  }
  if (proto.has_device_assignment()) {
    if (!module->config_.has_static_device_assignment()) {
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<DeviceAssignment> device_assignment,
          DeviceAssignment::Deserialize(proto.device_assignment()));
      module->config_.set_static_device_assignment(*device_assignment);
    }
  }
  return std::move(module);
}

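// Builds a module config whose entry computation layout mirrors the given
// ProgramShape. A sketch (assuming a populated `program_shape`;
// GetDebugOptionsFromFlags() comes from debug_options_flags.h):
//
//   TF_ASSIGN_OR_RETURN(HloModuleConfig config,
//                       HloModule::CreateModuleConfigFromShape(
//                           program_shape, GetDebugOptionsFromFlags(),
//                           /*execution_options=*/nullptr));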
/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
    const ProgramShape& program_shape, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  HloModuleConfig module_config(ProgramShape{program_shape});
  module_config.set_debug_options(debug_options);
  if (execution_options) {
    if (execution_options->num_replicas() > 0) {
      module_config.set_replica_count(execution_options->num_replicas());
    }
    if (execution_options->num_partitions() > 0) {
      module_config.set_num_partitions(execution_options->num_partitions());
    }
    module_config.set_use_spmd_partitioning(
        execution_options->use_spmd_partitioning());
    module_config.set_use_auto_spmd_partitioning(
        execution_options->use_auto_spmd_partitioning());
    std::vector<int64_t> mesh_shape;
    for (auto t : execution_options->auto_spmd_partitioning_mesh_shape()) {
      mesh_shape.push_back(t);
    }
    module_config.set_auto_spmd_partitioning_mesh_shape(mesh_shape);
    std::vector<int64_t> mesh_ids;
    for (auto t : execution_options->auto_spmd_partitioning_mesh_ids()) {
      mesh_ids.push_back(t);
    }
    module_config.set_auto_spmd_partitioning_mesh_ids(mesh_ids);
    module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
    module_config.set_allow_spmd_sharding_propagation_to_output(
        execution_options->allow_spmd_sharding_propagation_to_output());
    if (execution_options->has_device_assignment()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                          DeviceAssignment::Deserialize(
                              execution_options->device_assignment()));
      module_config.set_static_device_assignment(*device_assignment);
      if (execution_options->num_replicas() > 0) {
        CHECK_EQ(module_config.static_device_assignment().replica_count(),
                 module_config.replica_count());
      }
      if (execution_options->num_partitions() > 0) {
        CHECK_EQ(module_config.static_device_assignment().computation_count(),
                 module_config.num_partitions());
      }
    }
  }

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64_t i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  if (!module.has_host_program_shape()) {
    return tensorflow::errors::FailedPrecondition(
        "No program shape found in the proto");
  }
  ProgramShape program_shape(module.host_program_shape());
  TF_ASSIGN_OR_RETURN(HloModuleConfig config,
                      CreateModuleConfigFromShape(program_shape, debug_options,
                                                  execution_options));
  if (!config.has_static_device_assignment()) {
    if (module.has_device_assignment()) {
      // The execution options did not provide a device assignment, so fall
      // back to the one in the module proto.
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<DeviceAssignment> device_assignment,
          DeviceAssignment::Deserialize(module.device_assignment()));
      config.set_static_device_assignment(*device_assignment);
    }
  }
  return config;
}

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
                                 const absl::flat_hash_set<HloInstruction*>&
                                     instructions_in_subcomputation) {
  return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
    return !instructions_in_subcomputation.contains(user);
  });
}
}  // anonymous namespace

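// Moves a topologically ordered, single-output span of instructions into a
// fresh computation and replaces it with one kCall instruction. A sketch
// (hypothetical instruction pointers; `mul` feeds `add`, and only `add` is
// used outside the span):
//
//   HloInstruction* call = module->OutlineExpressionFromComputation(
//       {mul, add}, "outlined_mul_add", computation);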
HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const std::string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64_t parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64_t operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    std::string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

int64_t HloModule::instruction_count() const {
  int64_t n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

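// This overload filters the full post order down to the computations in
// `allow_list`, preserving their relative post-order positions.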
std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
    const absl::flat_hash_set<absl::string_view>& execution_threads,
    const absl::flat_hash_set<HloComputation*>& allow_list) const {
  std::vector<HloComputation*> filtered_post_order(allow_list.size());
  auto post_order = this->MakeComputationPostOrder(execution_threads);

  int filtered_idx = 0;
  for (auto& computation : post_order) {
    if (allow_list.contains(computation)) {
      filtered_post_order[filtered_idx] = computation;
      filtered_idx += 1;
    }
  }

  return filtered_post_order;
}

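// Builds a post order over all computations in the module: embedded
// computations appear before the root computations that call them, so
// visiting the result front to back always sees callees before callers.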
std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
  if (computations_.empty()) {
    return {};
  }
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  absl::flat_hash_set<HloComputation*> nonroot_computations;
  nonroot_computations.reserve(computations_.size() - 1);
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  absl::flat_hash_set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  added_computations.reserve(computations_.size());
  post_order.reserve(computations_.size());
  for (auto& computation : computations_) {
    if (nonroot_computations.contains(computation.get())) {
      continue;
    }
    for (HloComputation* embedded_computation :
         computation->MakeEmbeddedComputationsList()) {
      if (!added_computations.contains(embedded_computation)) {
        post_order.push_back(embedded_computation);
        added_computations.insert(embedded_computation);
      }
    }
    // Root computations should only be encountered once.
    CHECK(!added_computations.contains(computation.get()));
    post_order.push_back(computation.get());
    added_computations.insert(computation.get());
  }
  if (post_order.size() != computations_.size()) {
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
               << " computation_count=" << computations_.size();
  }
  if (execution_threads.empty()) {
    return post_order;
  }
  std::vector<HloComputation*> post_order_with_execution_threads;
  absl::c_copy_if(
      post_order, std::back_inserter(post_order_with_execution_threads),
      [&execution_threads](HloComputation* computation) {
        return execution_threads.find(computation->execution_thread()) !=
               execution_threads.end();
      });
  return post_order_with_execution_threads;
}

namespace {

class FingerprintMap {
 public:
  void Reserve(int capacity) { fingerprint_map_.reserve(capacity); }

  uint64_t GetFingerprint(const HloComputation* computation) {
    auto result = fingerprint_map_.try_emplace(computation, 0);
    if (result.second) {
      result.first->second =
          tensorflow::Fingerprint64(computation->ToString(print_options_));
    }
    return result.first->second;
  }

 private:
  HloPrintOptions print_options_ = HloPrintOptions::ModuleFingerprint();
  absl::flat_hash_map<const HloComputation*, uint64_t> fingerprint_map_;
};

void SortComputationsByContent(std::vector<HloComputation*>* computations) {
  FingerprintMap fingerprint_map;
  fingerprint_map.Reserve(computations->size());
  auto cmp = [&fingerprint_map](const HloComputation* a,
                                const HloComputation* b) {
    if (a->instruction_count() != b->instruction_count()) {
      return a->instruction_count() < b->instruction_count();
    }
    return fingerprint_map.GetFingerprint(a) <
           fingerprint_map.GetFingerprint(b);
  };
  absl::c_sort(*computations, cmp);
}

}  // anonymous namespace

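// The sorted variants below additionally order computations by instruction
// count and then by content fingerprint, which yields an order that is stable
// across runs when content_aware_computation_sorting() is enabled.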
std::vector<HloComputation*> HloModule::MakeComputationSorted(
    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
  std::vector<HloComputation*> result =
      MakeComputationPostOrder(execution_threads);
  if (config().content_aware_computation_sorting()) {
    SortComputationsByContent(&result);
  }
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations(
    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
  std::vector<HloComputation*> result =
      MakeComputationPostOrder(execution_threads);
  result.erase(std::remove_if(
                   result.begin(), result.end(),
                   [](HloComputation* c) { return c->IsFusionComputation(); }),
               result.end());
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted(
    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
  auto result = MakeNonfusionComputations(execution_threads);
  if (config().content_aware_computation_sorting()) {
    SortComputationsByContent(&result);
  }
  return result;
}

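// Clone copies everything reachable from the entry computation, plus the
// aliasing config, dynamism flag, schedule, and cross-program prefetches. A
// minimal sketch:
//
//   std::unique_ptr<HloModule> copy = module->Clone("clone");
//
// Computations that are unreachable from the entry computation are dropped by
// the clone.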
std::unique_ptr<HloModule> HloModule::Clone(const std::string& suffix) const {
  return Clone(config(), suffix);
}

std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
                                            const std::string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = absl::WrapUnique(new HloModule(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config,
      std::make_unique<CompilationEnvironments>(*comp_envs_)));

  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));
  module->input_output_alias_config() = input_output_alias_config();
  module->set_is_dynamic(is_dynamic());
  if (has_schedule() && schedule().Verify().ok()) {
    HloSchedule clone_schedule(module.get());
    for (HloComputation* computation : computations()) {
      if (schedule().is_computation_scheduled(computation)) {
        HloComputation* new_computation = context.FindComputation(computation);
        // The module being cloned may have computations that are dead, i.e.,
        // unreachable from the entry computation. In that case,
        // new_computation is nullptr.
        if (new_computation != nullptr) {
          HloInstructionSequence& clone_sequence =
              clone_schedule.GetOrCreateSequence(new_computation);
          for (const HloInstruction* instruction :
               schedule().sequence(computation).instructions()) {
            clone_sequence.push_back(context.GetInstruction(instruction));
          }
        }
      }
    }
    TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
  }
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    module->AddCrossProgramPrefetch(parameter, indices);
  }

  // To make clone behavior match uncloned behavior, we reorder
  // module->computations_ to match the order in computations_.
  using ComputationSorter = MappedPtrContainerSorter<HloComputation>;
  ComputationSorter::MapPtrFn computation_map_fn =
      [&context](const HloComputation* c) {
        return context.FindComputation(c);
      };
  auto status = ComputationSorter::Sort(
      computation_map_fn, ComputationSorter::IndexAfterMappedElementsFn(),
      computations_, module->computations_);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to sort module computations for " << name() << "; "
               << status;
  }

  return module;
}

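// Detects computations that are unreachable from the entry computation by
// cloning into a throwaway module and removing every computation the clone
// never visited.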
Status HloModule::RemoveUnusedComputations() {
  std::string suffix = "tmp";
  auto module = std::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config());
  HloCloneContext context(module.get(), suffix);
  entry_computation_->Clone(suffix, &context);
  std::vector<HloComputation*> to_remove;
  for (auto computation : computations()) {
    auto found_computation = context.FindComputation(computation);
    if (found_computation == nullptr) {
      to_remove.push_back(computation);
    }
  }
  for (auto computation : to_remove) {
    TF_RETURN_IF_ERROR(RemoveEmbeddedComputation(computation));
  }
  return OkStatus();
}

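// Clones `computation` (and, via HloComputation::Clone, its callees) into
// this module, reusing an existing clone when `context` has already mapped
// the computation.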
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
                                                HloCloneContext* context) {
  HloComputation* new_computation;
  if (context != nullptr) {
    if ((new_computation = context->FindComputation(computation)) != nullptr) {
      return new_computation;
    }
    new_computation =
        AddEmbeddedComputation(computation->Clone(context->suffix(), context));
  } else {
    new_computation = AddEmbeddedComputation(computation->Clone(""));
  }
  return new_computation;
}

uint64_t HloModule::RandomNew64() const {
  absl::MutexLock l(&rng_mutex_);
  return rng_();
}

HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
  auto computations_in_module = computations();
  auto it = absl::c_find_if(
      computations_in_module,
      [&](HloComputation* computation) { return computation->name() == name; });
  return it == computations_in_module.end() ? nullptr : *it;
}

HloModule::HloModule(const std::string& name, HloModuleConfig config,
                     std::unique_ptr<CompilationEnvironments> comp_envs)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(std::move(config)),
      unique_id_(next_unique_module_id_++),
      metadata_(tensorflow::Env::Default()),
      comp_envs_(std::move(comp_envs)) {
  metadata_.set_canonical_module_id(unique_id_);
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla