/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/compilation_environments.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"

namespace xla {

HloModule::HloModule(const std::string& name, HloModuleConfig config)
    : HloModule(name, config, std::make_unique<CompilationEnvironments>()) {}

Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return OkStatus();
}

void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) {
  entry_computation_ = entry_computation;
  config_.SetDefaultComputationLayout(
      entry_computation_->ComputeProgramShape());
  input_output_alias_config_ = HloInputOutputAliasConfig(
      entry_computation_->root_instruction()->shape());
}

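// Shared helper behind the AddEntryComputation* and AddEmbeddedComputation
// overloads: registers `computation` with this module, optionally marking it
// as the entry computation (which also establishes the entry layout and the
// input/output alias config), and either uniquifies its names and ids or
// records the existing ones so later additions cannot collide with them.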
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers, bool preserve_entry_layouts) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    if (preserve_entry_layouts) {
      config_.SetComputationLayoutIfExists(
          entry_computation_->ComputeProgramShape());
    } else if (!config_.has_entry_computation_layout()) {
      // If the module configuration has no entry computation layout set,
      // create a default one based on the program shape.
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
    input_output_alias_config_ = HloInputOutputAliasConfig(
        entry_computation_->root_instruction()->shape());
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set the unique id of this computation.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or its instructions, but we
    // must still run the names through the uniquifiers to prevent future name
    // collisions for computations and instructions created later. Also, set
    // next_unique_id_ to one greater than the max unique id of any instruction
    // (or the computation) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}

HloComputation* HloModule::AddEntryComputationWithLayouts(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/true);
}

Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  if (has_schedule() && !to_remove->IsCalledComputation()) {
    schedule_->remove_computation(to_remove);
  }

  auto it = absl::c_find_if(
      computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return OkStatus();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}

void HloModule::ReplaceComputations(
    const absl::flat_hash_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceScatter:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter:
        case HloOpcode::kSort: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          for (int b = 0; b < instruction->branch_count(); ++b) {
            HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
                replacements, instruction->branch_computation(b), nullptr);
            if (new_computation != nullptr) {
              instruction->set_branch_computation(b, new_computation);
            }
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

std::string HloModule::ToString(const HloPrintOptions& options) const {
  return std::string(ToCord(options));
}

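// Prints the module header followed by each computation. A hypothetical
// header line might look like
//   HloModule my_module, is_scheduled=true, entry_computation_layout={...}
// where each attribute is emitted only when the corresponding config or
// schedule state is present.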
absl::Cord HloModule::ToCord(const HloPrintOptions& options) const {
  absl::Cord result;
  result.Append("HloModule ");
  if (options.print_ids()) {
    // When print_ids() is false, skip the module's name, since it may embed
    // ids that would make the fingerprint non-deterministic.
    result.Append(name());
  }
  if (has_schedule()) {
    TF_CHECK_OK(schedule().Verify());
    result.Append(", is_scheduled=true");
  }
  std::string serialized_aliasing = input_output_alias_config().ToShortString();
  if (!serialized_aliasing.empty()) {
    result.Append(", input_output_alias={ ");
    result.Append(std::move(serialized_aliasing));
    result.Append(" }");
  }
  if (config_.alias_passthrough_params()) {
    result.Append(", alias_passthrough_params=true");
  }
  if (config_.has_entry_computation_layout()) {
    result.Append(", entry_computation_layout={");
    result.Append(entry_computation_layout().ToString());
    result.Append("}");
  }
  if (config_.allow_spmd_sharding_propagation_to_output()) {
    result.Append(", allow_spmd_sharding_propagation_to_output=true");
  }
  result.Append("\n\n");
  const auto& computations = options.canonicalize_computations()
                                 ? MakeComputationSorted()
                                 : MakeComputationPostOrder();
  for (const HloComputation* computation : computations) {
    // Don't print async computations when the syntax sugar is enabled, since
    // that is redundant information.
    if (options.syntax_sugar_async_ops() && computation->IsAsyncComputation()) {
      continue;
    }
    if (computation == entry_computation()) {
      result.Append("ENTRY ");
    }
    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
      result.Append(computation->ToCord(
          options, schedule().sequence(computation).instructions()));
    } else {
      result.Append(computation->ToCord(options));
    }
    result.Append("\n\n");
  }
  return result;
}

HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  if (entry_computation_) {
    proto.set_entry_computation_name(entry_computation_->name());
    proto.set_entry_computation_id(entry_computation_->unique_id());
    *proto.mutable_host_program_shape() =
        entry_computation_layout().ComputeProgramShape().ToProto();
  }
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
  *proto.mutable_dynamic_parameter_binding() =
      dynamic_parameter_binding().ToProto();
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    auto* prefetch = proto.mutable_cross_program_prefetches()->Add();
    prefetch->set_parameter(parameter);
    for (auto index : indices) {
      prefetch->add_index(index);
    }
  }
  proto.set_is_dynamic(is_dynamic_);
  if (has_spmd_output_sharding()) {
    *proto.mutable_spmd_output_sharding() = spmd_output_sharding().ToProto();
  }

  if (has_spmd_parameters_shardings()) {
    for (const auto& parameter_sharding : spmd_parameters_shardings()) {
      *proto.add_spmd_parameters_shardings() = parameter_sharding.ToProto();
    }
  }

  proto.set_use_auto_spmd_partitioning(use_auto_spmd_partitioning_);

  for (const HloModuleProto::ProfileInfo& profile_info : profile_info_list_) {
    HloModuleProto::ProfileInfo& profile_info_proto =
        *proto.mutable_profile_info()->Add();
    profile_info_proto.set_profile_type(profile_info.profile_type());
    profile_info_proto.set_relative_speedup(profile_info.relative_speedup());
    profile_info_proto.set_profile_source(profile_info.profile_source());
    profile_info_proto.set_compilation_event(profile_info.compilation_event());
  }
  if (this->config_.has_static_device_assignment()) {
    DeviceAssignmentProto device_assignment;
    TF_CHECK_OK(
        this->config_.static_device_assignment().Serialize(&device_assignment));
    (*proto.mutable_device_assignment()) = device_assignment;
  }
  return proto;
}

Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
  absl::flat_hash_set<std::string> computation_names;
  absl::flat_hash_set<int> computation_ids;
  absl::flat_hash_set<std::string> instruction_names;
  absl::flat_hash_set<int> instruction_ids;

  for (const HloComputation* computation : computations()) {
    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());

    TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
        << "Computation id is not unique: " << computation->unique_id();
    computation_ids.insert(computation->unique_id());

    for (const HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());

      TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
          << "Instruction id is not unique: " << instruction->unique_id();
      instruction_ids.insert(instruction->unique_id());
    }
  }
  return OkStatus();
}

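// Deserializes an HloModule from `proto`. Names and ids are taken verbatim
// from the proto (no uniquification) so that a module round-trips through
// serialization unchanged; their uniqueness is re-validated afterwards.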
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    bool prohibit_empty_literal) {
  VLOG(2) << "CreateFromProto()";
  XLA_VLOG_LINES(3, proto.DebugString());

  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape expected_program_shape(proto.host_program_shape());
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  absl::flat_hash_map<int64_t, HloComputation*> computation_map;
  absl::flat_hash_map<HloComputation*, int64_t> to_proto_id;
  std::vector<std::unique_ptr<HloComputation>> computations;
  HloComputation* entry = nullptr;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(computation_proto, computation_map,
                                        prohibit_empty_literal));
    CHECK_NE(computation.get(), nullptr);
    int64_t computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    computation_map[computation_id] = computation.get();
    to_proto_id[computation.get()] = computation_id;
    if (computation_id == proto.entry_computation_id()) {
      entry = computation.get();
    }
    computations.push_back(std::move(computation));
  }
  TF_RET_CHECK(entry != nullptr);

  auto module = std::make_unique<HloModule>(proto.name(), module_config);

  // Sort the computations in proto id order.
  absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
                                 const std::unique_ptr<HloComputation>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Add sorted computations to the module.
  for (auto& computation : computations) {
    bool is_entry = computation.get() == entry;
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    module->AddComputationInternal(std::move(computation), is_entry,
                                   /*uniquify_identifiers=*/false,
                                   /*preserve_entry_layouts=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);
  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config_,
      HloInputOutputAliasConfig::CreateFromProto(
          entry->ComputeProgramShape().result(), proto.input_output_alias()));

  TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
                      DynamicParameterBinding::CreateFromProto(
                          proto.dynamic_parameter_binding()));

  // Because we didn't uniquify the names or the ids, double-check that the
  // instruction and computation names and ids from the proto are unique.
  TF_RETURN_IF_ERROR(
      module->CheckUniqueNamesAndIdsForComputationsAndInstructions());

  if (proto.has_schedule()) {
    TF_ASSIGN_OR_RETURN(
        HloSchedule schedule,
        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  }

  for (const auto& prefetch : proto.cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  module->set_is_dynamic(proto.is_dynamic());

  if (proto.has_spmd_output_sharding()) {
    TF_ASSIGN_OR_RETURN(HloSharding hlo_sharding,
                        HloSharding::FromProto(proto.spmd_output_sharding()));
    module->set_spmd_output_sharding(hlo_sharding);
  }

  std::vector<HloSharding> param_shardings;
  for (const auto& sharding_proto : proto.spmd_parameters_shardings()) {
    TF_ASSIGN_OR_RETURN(HloSharding sharding,
                        HloSharding::FromProto(sharding_proto));
    param_shardings.push_back(sharding);
  }
  if (!param_shardings.empty()) {
    module->set_spmd_parameters_shardings(param_shardings);
  }

  module->set_use_auto_spmd_partitioning(proto.use_auto_spmd_partitioning());

  for (const auto& profile_info : proto.profile_info()) {
    module->add_profile_info(profile_info);
  }
  if (proto.has_device_assignment()) {
    if (!module->config_.has_static_device_assignment()) {
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<DeviceAssignment> device_assignment,
          DeviceAssignment::Deserialize(proto.device_assignment()));
      module->config_.set_static_device_assignment(*device_assignment);
    }
  }
  return std::move(module);
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
    const ProgramShape& program_shape, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  HloModuleConfig module_config(ProgramShape{program_shape});
  module_config.set_debug_options(debug_options);
  if (execution_options) {
    if (execution_options->num_replicas() > 0) {
      module_config.set_replica_count(execution_options->num_replicas());
    }
    if (execution_options->num_partitions() > 0) {
      module_config.set_num_partitions(execution_options->num_partitions());
    }
    module_config.set_use_spmd_partitioning(
        execution_options->use_spmd_partitioning());
    module_config.set_use_auto_spmd_partitioning(
        execution_options->use_auto_spmd_partitioning());
    std::vector<int64_t> mesh_shape;
    for (auto t : execution_options->auto_spmd_partitioning_mesh_shape()) {
      mesh_shape.push_back(t);
    }
    module_config.set_auto_spmd_partitioning_mesh_shape(mesh_shape);
    std::vector<int64_t> mesh_ids;
    for (auto t : execution_options->auto_spmd_partitioning_mesh_ids()) {
      mesh_ids.push_back(t);
    }
    module_config.set_auto_spmd_partitioning_mesh_ids(mesh_ids);
    module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
    module_config.set_allow_spmd_sharding_propagation_to_output(
        execution_options->allow_spmd_sharding_propagation_to_output());
    if (execution_options->has_device_assignment()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                          DeviceAssignment::Deserialize(
                              execution_options->device_assignment()));
      module_config.set_static_device_assignment(*device_assignment);
      if (execution_options->num_replicas() > 0) {
        CHECK_EQ(module_config.static_device_assignment().replica_count(),
                 module_config.replica_count());
      }
      if (execution_options->num_partitions() > 0) {
        CHECK_EQ(module_config.static_device_assignment().computation_count(),
                 module_config.num_partitions());
      }
    }
  }

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64_t i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  if (!module.has_host_program_shape()) {
    return tensorflow::errors::FailedPrecondition(
        "No program shape found in the proto");
  }
  ProgramShape program_shape(module.host_program_shape());
  TF_ASSIGN_OR_RETURN(HloModuleConfig config,
                      CreateModuleConfigFromShape(program_shape, debug_options,
                                                  execution_options));
  if (!config.has_static_device_assignment()) {
    if (module.has_device_assignment()) {
      // The device assignment from the execution options takes precedence;
      // only fall back to the one stored in the module proto when the config
      // does not already have one.
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<DeviceAssignment> device_assignment,
          DeviceAssignment::Deserialize(module.device_assignment()));
      config.set_static_device_assignment(*device_assignment);
    }
  }
  return config;
}

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
                                 const absl::flat_hash_set<HloInstruction*>&
                                     instructions_in_subcomputation) {
  return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
    return !instructions_in_subcomputation.contains(user);
  });
}
}  // anonymous namespace

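// Moves `instructions_to_outline` (given in topological order) out of
// `computation` and into a freshly built computation, replacing them with a
// single kCall instruction. Exactly one outlined instruction may be used
// outside the outlined set; it becomes the root of the new computation.
//
// Hypothetical example: outlining {%a = add(x, y), %m = multiply(%a, z)}
// yields a computation with parameters for x, y, and z whose root is the
// multiply, and the original uses of %m are rewritten to a call into it.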
HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const std::string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64_t parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64_t operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    std::string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

int64_t HloModule::instruction_count() const {
  int64_t n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

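// Same as the single-argument MakeComputationPostOrder below, but the result
// is restricted to computations that appear in `allow_list` (their relative
// post order is preserved).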
std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
    const absl::flat_hash_set<absl::string_view>& execution_threads,
    const absl::flat_hash_set<HloComputation*>& allow_list) const {
  std::vector<HloComputation*> filtered_post_order(allow_list.size());
  auto post_order = this->MakeComputationPostOrder(execution_threads);

  int filtered_idx = 0;
  for (auto& computation : post_order) {
    if (allow_list.contains(computation)) {
      filtered_post_order[filtered_idx] = computation;
      filtered_idx += 1;
    }
  }

  return filtered_post_order;
}

std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
  if (computations_.empty()) {
    return {};
  }
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  absl::flat_hash_set<HloComputation*> nonroot_computations;
  nonroot_computations.reserve(computations_.size() - 1);
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  absl::flat_hash_set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  added_computations.reserve(computations_.size());
  post_order.reserve(computations_.size());
  for (auto& computation : computations_) {
    if (nonroot_computations.contains(computation.get())) {
      continue;
    }
    for (HloComputation* embedded_computation :
         computation->MakeEmbeddedComputationsList()) {
      if (!added_computations.contains(embedded_computation)) {
        post_order.push_back(embedded_computation);
        added_computations.insert(embedded_computation);
      }
    }
    // Root computations should only be encountered once.
    CHECK(!added_computations.contains(computation.get()));
    post_order.push_back(computation.get());
    added_computations.insert(computation.get());
  }
  if (post_order.size() != computations_.size()) {
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
               << " computation_count=" << computations_.size();
  }
  if (execution_threads.empty()) {
    return post_order;
  }
  std::vector<HloComputation*> post_order_with_execution_threads;
  absl::c_copy_if(
      post_order, std::back_inserter(post_order_with_execution_threads),
      [&execution_threads](HloComputation* computation) {
        return execution_threads.find(computation->execution_thread()) !=
               execution_threads.end();
      });
  return post_order_with_execution_threads;
}

namespace {

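// Caches a 64-bit fingerprint of each computation's fingerprint-mode string
// so that the sorting below does not repeatedly re-serialize the same
// computation.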
class FingerprintMap {
 public:
  void Reserve(int capacity) { fingerprint_map_.reserve(capacity); }

  uint64_t GetFingerprint(const HloComputation* computation) {
    auto result = fingerprint_map_.try_emplace(computation, 0);
    if (result.second) {
      result.first->second =
          tensorflow::Fingerprint64(computation->ToString(print_options_));
    }
    return result.first->second;
  }

 private:
  HloPrintOptions print_options_ = HloPrintOptions::ModuleFingerprint();
  absl::flat_hash_map<const HloComputation*, uint64_t> fingerprint_map_;
};

void SortComputationsByContent(std::vector<HloComputation*>* computations) {
  FingerprintMap fingerprint_map;
  fingerprint_map.Reserve(computations->size());
  auto cmp = [&fingerprint_map](const HloComputation* a,
                                const HloComputation* b) {
    if (a->instruction_count() != b->instruction_count()) {
      return a->instruction_count() < b->instruction_count();
    }
    return fingerprint_map.GetFingerprint(a) <
           fingerprint_map.GetFingerprint(b);
  };
  absl::c_sort(*computations, cmp);
}

}  // anonymous namespace

std::vector<HloComputation*> HloModule::MakeComputationSorted(
    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
  std::vector<HloComputation*> result =
      MakeComputationPostOrder(execution_threads);
  if (config().content_aware_computation_sorting()) {
    SortComputationsByContent(&result);
  }
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations(
    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
  std::vector<HloComputation*> result =
      MakeComputationPostOrder(execution_threads);
  result.erase(std::remove_if(
                   result.begin(), result.end(),
                   [](HloComputation* c) { return c->IsFusionComputation(); }),
               result.end());
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted(
    const absl::flat_hash_set<absl::string_view>& execution_threads) const {
  auto result = MakeNonfusionComputations(execution_threads);
  if (config().content_aware_computation_sorting()) {
    SortComputationsByContent(&result);
  }
  return result;
}

std::unique_ptr<HloModule> HloModule::Clone(const std::string& suffix) const {
  return Clone(config(), suffix);
}

std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
                                            const std::string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = absl::WrapUnique(new HloModule(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config,
      std::make_unique<CompilationEnvironments>(*comp_envs_)));

  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));
  module->input_output_alias_config() = input_output_alias_config();
  module->set_is_dynamic(is_dynamic());
  if (has_schedule() && schedule().Verify().ok()) {
    HloSchedule clone_schedule(module.get());
    for (HloComputation* computation : computations()) {
      if (schedule().is_computation_scheduled(computation)) {
        HloComputation* new_computation = context.FindComputation(computation);
        // The module being cloned may have computations that are dead, i.e.,
        // unreachable from the entry computation. In that case,
        // new_computation is nullptr.
        if (new_computation != nullptr) {
          HloInstructionSequence& clone_sequence =
              clone_schedule.GetOrCreateSequence(new_computation);
          for (const HloInstruction* instruction :
               schedule().sequence(computation).instructions()) {
            clone_sequence.push_back(context.GetInstruction(instruction));
          }
        }
      }
    }
    TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
  }
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    module->AddCrossProgramPrefetch(parameter, indices);
  }

  // To make clone behavior match uncloned behavior, we reorder
  // module->computations_ to match the order in computations_.
  using ComputationSorter = MappedPtrContainerSorter<HloComputation>;
  ComputationSorter::MapPtrFn computation_map_fn =
      [&context](const HloComputation* c) {
        return context.FindComputation(c);
      };
  auto status = ComputationSorter::Sort(
      computation_map_fn, ComputationSorter::IndexAfterMappedElementsFn(),
      computations_, module->computations_);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to sort module computations for " << name() << "; "
               << status;
  }

  return module;
}

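// Clones the entry computation into a throwaway module to discover which
// computations are reachable from it; every computation that did not get
// cloned is unreachable and is removed from this module.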
Status HloModule::RemoveUnusedComputations() {
  std::string suffix = "tmp";
  auto module = std::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config());
  HloCloneContext context(module.get(), suffix);
  entry_computation_->Clone(suffix, &context);
  std::vector<HloComputation*> to_remove;
  for (auto computation : computations()) {
    auto found_computation = context.FindComputation(computation);
    if (found_computation == nullptr) {
      to_remove.push_back(computation);
    }
  }
  for (auto computation : to_remove) {
    TF_RETURN_IF_ERROR(RemoveEmbeddedComputation(computation));
  }
  return OkStatus();
}

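// Clones `computation` into this module as an embedded computation. When a
// clone context is provided, a computation that has already been cloned in
// that context is reused rather than cloned again.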
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
                                                HloCloneContext* context) {
  HloComputation* new_computation;
  if (context != nullptr) {
    if ((new_computation = context->FindComputation(computation)) != nullptr) {
      return new_computation;
    }
    new_computation =
        AddEmbeddedComputation(computation->Clone(context->suffix(), context));
  } else {
    new_computation = AddEmbeddedComputation(computation->Clone(""));
  }
  return new_computation;
}

uint64_t HloModule::RandomNew64() const {
  absl::MutexLock l(&rng_mutex_);
  return rng_();
}

HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
  auto computations_in_module = computations();
  auto it = absl::c_find_if(
      computations_in_module,
      [&](HloComputation* computation) { return computation->name() == name; });
  return it == computations_in_module.end() ? nullptr : *it;
}

HloModule::HloModule(const std::string& name, HloModuleConfig config,
                     std::unique_ptr<CompilationEnvironments> comp_envs)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(std::move(config)),
      unique_id_(next_unique_module_id_++),
      metadata_(tensorflow::Env::Default()),
      comp_envs_(std::move(comp_envs)) {
  metadata_.set_canonical_module_id(unique_id_);
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla