/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

HloModule::HloModule(const string& name, HloModuleConfig config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(std::move(config)),
      unique_id_(next_unique_module_id_++) {}

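// Installs `schedule` as this module's schedule. The schedule must belong to
// this module and must verify against the module's current computations.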
Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return Status::OK();
}

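// Swaps in a new entry computation and rederives the metadata that depends on
// it: the default entry computation layout and the input/output alias config.
// The new entry computation is expected to already live in this module.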
void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) {
  entry_computation_ = entry_computation;
  config_.SetDefaultComputationLayout(
      entry_computation_->ComputeProgramShape());
  input_output_alias_config_ = HloInputOutputAliasConfig(
      entry_computation_->root_instruction()->shape());
}

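// Shared implementation for adding a computation. When uniquify_identifiers
// is true, fresh names and ids are assigned; otherwise the existing
// identifiers are kept, and the name uniquifiers and the id counter are
// advanced past them so that later additions cannot collide.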
HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    // If the module configuration has no entry computation layout set, create
    // a default one based on the program shape.
    if (!config_.has_entry_computation_layout()) {
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
    input_output_alias_config_ = HloInputOutputAliasConfig(
        entry_computation_->root_instruction()->shape());
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set the unique id of this computation to that of its root instruction.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or its instructions, but we
    // must run the names through the uniquifiers to prevent future name
    // collisions for computations and instructions created later. Also, set
    // next_unique_id_ to one greater than the max unique id of any instruction
    // (or of the computation itself) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true);
}

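// Removes an embedded (non-entry) computation. The schedule is only updated
// for non-fusion computations, since fusion computations are not tracked by
// the schedule.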
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  if (has_schedule() && !to_remove->IsFusionComputation()) {
    schedule_->remove_computation(to_remove);
  }

  auto it = absl::c_find_if(
      computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return Status::OK();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_identifiers=*/true);
}

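// For each (key, value) pair in `replacements`, rewires every call site of
// the key computation to the value and then drops the key computation from
// the module; the entry computation is remapped the same way.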
void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter:
        case HloOpcode::kSort: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          for (int b = 0; b < instruction->branch_count(); ++b) {
            HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
                replacements, instruction->branch_computation(b), nullptr);
            if (new_computation != nullptr) {
              instruction->set_branch_computation(b, new_computation);
            }
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

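// Prints the module in HLO text format. Scheduled computations print their
// instructions in schedule order; with canonicalize_computations(), the
// computations themselves are printed in content order rather than post
// order.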
string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  s << "HloModule " << PrintName(name(), options.print_ids());
  if (has_schedule()) {
    TF_CHECK_OK(schedule().Verify());
    s << ", is_scheduled=true";
  }
  s << "\n\n";
  const auto& computations = options.canonicalize_computations()
                                 ? MakeComputationSortedByContent()
                                 : MakeComputationPostOrder();
  for (const HloComputation* computation : computations) {
    if (!options.print_computation(computation)) {
      continue;
    }
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
      s << computation->ToString(
               options, schedule().sequence(computation).instructions())
        << "\n\n";
    } else {
      s << computation->ToString(options) << "\n\n";
    }
  }
  return s.str();
}

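// Serializes the module: the computations in post order, the optional
// schedule, and the entry layout, aliasing, and dynamic-binding metadata.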
HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  proto.set_entry_computation_id(entry_computation_->unique_id());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  *proto.mutable_host_program_shape() =
      entry_computation_layout().ComputeProgramShape().ToProto();
  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
  *proto.mutable_dynamic_parameter_binding() =
      dynamic_parameter_binding().ToProto();
  return proto;
}

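// Verifies that no two computations and no two instructions in the module
// share a name or a unique id. Used by CreateFromProto, where identifiers
// come from the proto instead of being uniquified.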
Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
  absl::flat_hash_set<string> computation_names;
  absl::flat_hash_set<int> computation_ids;
  absl::flat_hash_set<string> instruction_names;
  absl::flat_hash_set<int> instruction_ids;

  for (const HloComputation* computation : computations()) {
    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());

    TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
        << "Computation id is not unique: " << computation->unique_id();
    computation_ids.insert(computation->unique_id());

    for (const HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());

      TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
          << "Instruction id is not unique: " << instruction->unique_id();
      instruction_ids.insert(instruction->unique_id());
    }
  }
  return Status::OK();
}

/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    bool prohibit_empty_literal) {
  VLOG(2) << "CreateFromProto()";
  XLA_VLOG_LINES(3, proto.DebugString());

  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape expected_program_shape(proto.host_program_shape());
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

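  // Deserialize the computations. computation_map lets each computation
  // resolve, by proto id, the already-created computations it calls, which
  // assumes callees appear before their callers in the proto.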
  absl::flat_hash_map<int64, HloComputation*> computation_map;
  absl::flat_hash_map<HloComputation*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloComputation>> computations;
  HloComputation* entry = nullptr;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(computation_proto, computation_map,
                                        prohibit_empty_literal));
    CHECK_NE(computation.get(), nullptr);
    int64 computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    computation_map[computation_id] = computation.get();
    to_proto_id[computation.get()] = computation_id;
    if (computation_id == proto.entry_computation_id()) {
      entry = computation.get();
    }
    computations.push_back(std::move(computation));
  }
  TF_RET_CHECK(entry != nullptr);

  auto module = absl::make_unique<HloModule>(proto.name(), module_config);

  // Sort the computations in the order of their proto ids.
  absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
                                 const std::unique_ptr<HloComputation>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Add sorted computations to the module.
  for (auto& computation : computations) {
    bool is_entry = computation.get() == entry;
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    module->AddComputationInternal(std::move(computation), is_entry,
                                   /*uniquify_identifiers=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config_,
      HloInputOutputAliasConfig::CreateFromProto(
          entry->ComputeProgramShape().result(), proto.input_output_alias()));

  TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
                      DynamicParameterBinding::CreateFromProto(
                          proto.dynamic_parameter_binding()));

  // Because we didn't uniquify the names or the ids, double-check that the
  // instruction and computation names and ids from the proto are unique.
  TF_RETURN_IF_ERROR(
      module->CheckUniqueNamesAndIdsForComputationsAndInstructions());

  if (proto.has_schedule()) {
    TF_ASSIGN_OR_RETURN(
        HloSchedule schedule,
        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  }

  return std::move(module);
}

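// Builds a module config for a program with the given shape: applies
// replica/partition counts and any static device assignment from
// `execution_options`, then copies the parameter and result layouts from
// `program_shape` over the config's defaults.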
/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
    const ProgramShape& program_shape, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  HloModuleConfig module_config(ProgramShape{program_shape});
  module_config.set_debug_options(debug_options);
  if (execution_options) {
    if (execution_options->num_replicas() > 0) {
      module_config.set_replica_count(execution_options->num_replicas());
    }
    if (execution_options->num_partitions() > 0) {
      module_config.set_num_partitions(execution_options->num_partitions());
    }
    if (execution_options->has_device_assignment()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                          DeviceAssignment::Deserialize(
                              execution_options->device_assignment()));
      module_config.set_static_device_assignment(*device_assignment);
      if (execution_options->num_replicas() > 0) {
        CHECK_EQ(module_config.static_device_assignment().replica_count(),
                 module_config.replica_count());
      }
      if (execution_options->num_partitions() > 0) {
        CHECK_EQ(module_config.static_device_assignment().computation_count(),
                 module_config.num_partitions());
      }
    }
  }

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  TF_RET_CHECK(module.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape program_shape(module.host_program_shape());
  return CreateModuleConfigFromShape(program_shape, debug_options,
                                     execution_options);
}

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
                                 const absl::flat_hash_set<HloInstruction*>&
                                     instructions_in_subcomputation) {
  return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
    return !instructions_in_subcomputation.contains(user);
  });
}
}  // anonymous namespace

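// Moves the given instructions, which must be in topological order, out of
// `computation` and into a new computation with the given name: the
// instructions are cloned into the new computation, a kCall to it replaces
// them, and the originals are removed. Exactly one outlined instruction may
// be live outside the set; it becomes the root of the new computation.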
HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

int64 HloModule::instruction_count() const {
  int64 n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

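// Returns all computations in post order: every computation appears after
// the computations it (transitively) calls. The checks at the end catch
// computations that would be visited twice or not at all.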
std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  absl::flat_hash_set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  absl::flat_hash_set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (!nonroot_computations.contains(computation.get())) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (!added_computations.contains(embedded_computation)) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK(!added_computations.contains(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  if (post_order.size() != computations_.size()) {
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
               << " computation_count=" << computations_.size();
  }
  return post_order;
}

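// Returns the computations sorted by content: by instruction count first,
// then by a fingerprint of their text form. This gives an ordering that
// depends only on the computations themselves, which is what canonicalized
// printing in ToString() relies on.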
std::vector<HloComputation*> HloModule::MakeComputationSortedByContent() const {
  auto result = MakeComputationPostOrder();
  std::sort(result.begin(), result.end(),
            [](HloComputation* a, HloComputation* b) {
              if (a->instruction_count() != b->instruction_count()) {
                return a->instruction_count() < b->instruction_count();
              }
              return a->ToString(HloPrintOptions::Fingerprint()) <
                     b->ToString(HloPrintOptions::Fingerprint());
            });
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result;
  for (auto* c : computations()) {
    if (c->IsFusionComputation()) {
      continue;
    }
    result.push_back(c);
  }
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted()
    const {
  auto result = MakeNonfusionComputations();
  std::sort(result.begin(), result.end(),
            [](HloComputation* a, HloComputation* b) {
              return a->name() < b->name();
            });
  return result;
}

std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  return Clone(config(), suffix);
}

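// Clones the module under the given config. Only computations reachable from
// the entry computation are cloned, and the schedule, when present and
// valid, is carried over to the cloned instructions.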
std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
                                            const string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config);

  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));

  if (has_schedule() && schedule().Verify().ok()) {
    HloSchedule clone_schedule(module.get());
    for (HloComputation* computation : computations()) {
      if (schedule().is_computation_scheduled(computation)) {
        HloInstructionSequence& clone_sequence =
            clone_schedule.GetOrCreateSequence(
                context.GetComputation(computation));
        for (const HloInstruction* instruction :
             schedule().sequence(computation).instructions()) {
          clone_sequence.push_back(context.GetInstruction(instruction));
        }
      }
    }
    TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
  }
  return module;
}

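// Removes computations that are no longer reachable from the entry
// computation. Reachability is discovered by cloning the entry computation
// into a throwaway module and removing every computation the clone context
// never visited.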
Status HloModule::RemoveUnusedComputations() {
  std::string suffix = "tmp";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config());
  HloCloneContext context(module.get(), suffix);
  entry_computation_->Clone(suffix, &context);
  std::vector<HloComputation*> to_remove;
  for (auto computation : computations()) {
    auto found_computation = context.FindComputation(computation);
    if (found_computation == nullptr) {
      to_remove.push_back(computation);
    }
  }
  for (auto computation : to_remove) {
    TF_RETURN_IF_ERROR(RemoveEmbeddedComputation(computation));
  }
  return Status::OK();
}

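// Clones `computation` (and, via the clone context, the computations it
// references) into this module. With a context, a computation that was
// already cloned is reused; without one, a fresh clone is always added.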
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
                                                HloCloneContext* context) {
  HloComputation* new_computation;
  if (context != nullptr) {
    if ((new_computation = context->FindComputation(computation)) != nullptr) {
      return new_computation;
    }
    new_computation =
        AddEmbeddedComputation(computation->Clone(context->suffix(), context));
  } else {
    new_computation = AddEmbeddedComputation(computation->Clone(""));
  }
  return new_computation;
}

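// Returns a value from the module's random number generator. The mutex makes
// this safe to call from multiple threads.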
uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

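// Returns the computation with the given name, or nullptr if there is none.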
HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
  auto computations_in_module = computations();
  auto it = absl::c_find_if(
      computations_in_module,
      [&](HloComputation* computation) { return computation->name() == name; });
  return it == computations_in_module.end() ? nullptr : *it;
}

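// Combines the hash of the entry computation layout with the hash of the
// entry computation's root instruction.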
uint64 HloModule::Hash() const {
  return tensorflow::Hash64Combine(
      entry_computation_layout().Hash(),
      entry_computation()->root_instruction()->Hash());
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla