/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <algorithm>
#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
// Needed for absl::StrFormat in ToString() below.
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

HloModule::HloModule(const string& name, HloModuleConfig config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(std::move(config)),
      unique_id_(next_unique_module_id_++),
      metadata_(tensorflow::Env::Default()) {
  metadata_.set_canonical_module_id(unique_id_);
}

Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return Status::OK();
}
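
// Usage sketch (added commentary, not part of the original file): a schedule
// is typically produced by a scheduler such as ScheduleModule() from
// hlo_memory_scheduler.h and then installed here; `size_function` is an
// assumed callback mapping buffers to byte sizes.
//
//   TF_ASSIGN_OR_RETURN(HloSchedule schedule,
//                       ScheduleModule(module.get(), size_function));
//   TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));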

void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) {
  entry_computation_ = entry_computation;
  config_.SetDefaultComputationLayout(
      entry_computation_->ComputeProgramShape());
  input_output_alias_config_ = HloInputOutputAliasConfig(
      entry_computation_->root_instruction()->shape());
}

HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers, bool preserve_entry_layouts) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    if (preserve_entry_layouts) {
      config_.SetComputationLayoutIfExists(
          entry_computation_->ComputeProgramShape());
    } else if (!config_.has_entry_computation_layout()) {
      // If the module configuration has no entry computation layout set,
      // create a default one based on the program shape.
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
    input_output_alias_config_ = HloInputOutputAliasConfig(
        entry_computation_->root_instruction()->shape());
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set the unique id of this computation.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or its instructions, but
    // still run the names through the uniquifiers to prevent future name
    // collisions for computations and instructions created later. Also, set
    // next_unique_id_ to one greater than the max unique id of any
    // instruction (or the computation) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}

HloComputation* HloModule::AddEntryComputationWithLayouts(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/true);
}
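
// Hedged example of how the entry points above are commonly used when
// building a module by hand (e.g. in tests); the module name, shapes, and
// instruction names are made up:
//
//   auto module = absl::make_unique<HloModule>("example", HloModuleConfig());
//   auto builder = HloComputation::Builder("entry");
//   HloInstruction* x = builder.AddInstruction(
//       HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {8}),
//                                       "x"));
//   builder.AddInstruction(
//       HloInstruction::CreateUnary(x->shape(), HloOpcode::kNegate, x));
//   module->AddEntryComputation(builder.Build());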

Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  if (has_schedule() && !to_remove->IsCalledComputation()) {
    schedule_->remove_computation(to_remove);
  }

  auto it = absl::c_find_if(
      computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return Status::OK();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}

void HloModule::ReplaceComputations(
    const absl::flat_hash_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceScatter:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter:
        case HloOpcode::kSort: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          for (int b = 0; b < instruction->branch_count(); ++b) {
            HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
                replacements, instruction->branch_computation(b), nullptr);
            if (new_computation != nullptr) {
              instruction->set_branch_computation(b, new_computation);
            }
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}
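
// Sketch of a typical caller (assumptions: `duplicates` and the
// `canonical_for` helper are hypothetical): a deduplication pass maps each
// redundant computation to its canonical representative and rewrites all
// call sites in one pass over the module.
//
//   absl::flat_hash_map<HloComputation*, HloComputation*> replacements;
//   for (HloComputation* duplicate : duplicates) {
//     replacements[duplicate] = canonical_for(duplicate);
//   }
//   module->ReplaceComputations(replacements);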

string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  // When print_ids() is false, exclude the module's name because it includes
  // ids and would lead to a non-deterministic fingerprint.
  s << "HloModule "
    << (options.print_ids() ? PrintName(name(), options.print_ids()) : "");
  if (has_schedule()) {
    TF_CHECK_OK(schedule().Verify());
    s << ", is_scheduled=true";
  }
  std::string serialized_aliasing = input_output_alias_config().ToShortString();
  if (!serialized_aliasing.empty()) {
    s << absl::StrFormat(", input_output_alias={ %s }", serialized_aliasing);
  }
  if (config_.alias_passthrough_params()) {
    s << ", alias_passthrough_params=true";
  }
  s << "\n\n";
  const auto& computations = options.canonicalize_computations()
                                 ? MakeComputationSorted()
                                 : MakeComputationPostOrder();
  for (const HloComputation* computation : computations) {
    if (!options.print_computation(computation)) {
      continue;
    }
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
      s << computation->ToString(
               options, schedule().sequence(computation).instructions())
        << "\n\n";
    } else {
      s << computation->ToString(options) << "\n\n";
    }
  }
  return s.str();
}
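
// Illustrative only: two common invocations of this printer. The
// Fingerprint() options suppress ids, so two modules that differ only in
// names/ids print identically.
//
//   std::string full_text = module->ToString();
//   std::string stable_text =
//       module->ToString(HloPrintOptions::Fingerprint());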

HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  proto.set_entry_computation_id(entry_computation_->unique_id());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  *proto.mutable_host_program_shape() =
      entry_computation_layout().ComputeProgramShape().ToProto();
  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
  *proto.mutable_dynamic_parameter_binding() =
      dynamic_parameter_binding().ToProto();
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    auto* prefetch = proto.mutable_cross_program_prefetches()->Add();
    prefetch->set_parameter(parameter);
    for (auto index : indices) {
      prefetch->add_index(index);
    }
  }
  proto.set_is_dynamic(is_dynamic_);
  return proto;
}

Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions()
    const {
  absl::flat_hash_set<string> computation_names;
  absl::flat_hash_set<int> computation_ids;
  absl::flat_hash_set<string> instruction_names;
  absl::flat_hash_set<int> instruction_ids;

  for (const HloComputation* computation : computations()) {
    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());

    TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
        << "Computation id is not unique: " << computation->unique_id();
    computation_ids.insert(computation->unique_id());

    for (const HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());

      TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
          << "Instruction id is not unique: " << instruction->unique_id();
      instruction_ids.insert(instruction->unique_id());
    }
  }
  return Status::OK();
}

/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    bool prohibit_empty_literal) {
  VLOG(2) << "CreateFromProto()";
  XLA_VLOG_LINES(3, proto.DebugString());

  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape expected_program_shape(proto.host_program_shape());
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  absl::flat_hash_map<int64, HloComputation*> computation_map;
  absl::flat_hash_map<HloComputation*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloComputation>> computations;
  HloComputation* entry = nullptr;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(computation_proto, computation_map,
                                        prohibit_empty_literal));
    CHECK_NE(computation.get(), nullptr);
    int64_t computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    computation_map[computation_id] = computation.get();
    to_proto_id[computation.get()] = computation_id;
    if (computation_id == proto.entry_computation_id()) {
      entry = computation.get();
    }
    computations.push_back(std::move(computation));
  }
  TF_RET_CHECK(entry != nullptr);

  auto module = absl::make_unique<HloModule>(proto.name(), module_config);

  // Sort the computations in the order of their proto ids.
  absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
                                 const std::unique_ptr<HloComputation>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Add sorted computations to the module.
  for (auto& computation : computations) {
    bool is_entry = computation.get() == entry;
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    module->AddComputationInternal(std::move(computation), is_entry,
                                   /*uniquify_identifiers=*/false,
                                   /*preserve_entry_layouts=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config_,
      HloInputOutputAliasConfig::CreateFromProto(
          entry->ComputeProgramShape().result(), proto.input_output_alias()));

  // Because we didn't uniquify the names or the ids, double-check that the
  // computation and instruction names and ids from the proto are unique.
  TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
                      DynamicParameterBinding::CreateFromProto(
                          proto.dynamic_parameter_binding()));

  TF_RETURN_IF_ERROR(
      module->CheckUniqueNamesAndIdsForComputationsAndInstructions());

  if (proto.has_schedule()) {
    TF_ASSIGN_OR_RETURN(
        HloSchedule schedule,
        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  }

  for (auto prefetch : proto.cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  module->set_is_dynamic(proto.is_dynamic());

  return std::move(module);
}
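
// Round-trip sketch (added commentary; assumes the flag-derived debug
// options are acceptable for the reconstructed module):
//
//   HloModuleProto proto = module->ToProto();
//   TF_ASSIGN_OR_RETURN(HloModuleConfig config,
//                       HloModule::CreateModuleConfigFromProto(
//                           proto, GetDebugOptionsFromFlags()));
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> restored,
//                       HloModule::CreateFromProto(proto, config));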

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
    const ProgramShape& program_shape, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  HloModuleConfig module_config(ProgramShape{program_shape});
  module_config.set_debug_options(debug_options);
  if (execution_options) {
    if (execution_options->num_replicas() > 0) {
      module_config.set_replica_count(execution_options->num_replicas());
    }
    if (execution_options->num_partitions() > 0) {
      module_config.set_num_partitions(execution_options->num_partitions());
    }
    module_config.set_use_spmd_partitioning(
        execution_options->use_spmd_partitioning());
    module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
    if (execution_options->has_device_assignment()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                          DeviceAssignment::Deserialize(
                              execution_options->device_assignment()));
      module_config.set_static_device_assignment(*device_assignment);
      if (execution_options->num_replicas() > 0) {
        CHECK_EQ(module_config.static_device_assignment().replica_count(),
                 module_config.replica_count());
      }
      if (execution_options->num_partitions() > 0) {
        CHECK_EQ(module_config.static_device_assignment().computation_count(),
                 module_config.num_partitions());
      }
    }
  }

  // The module config is constructed with default layouts regardless of what is
  // passed in via the ProgramShape. Set the layouts to the appropriate values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64_t i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  TF_RET_CHECK(module.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape program_shape(module.host_program_shape());
  return CreateModuleConfigFromShape(program_shape, debug_options,
                                     execution_options);
}
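
// Hedged sketch of the shape-based variant: `program_shape`, `debug_options`,
// and `execution_options` are assumed to come from the client request. Note
// that replica/partition counts are only overridden when the corresponding
// ExecutionOptions fields are positive.
//
//   TF_ASSIGN_OR_RETURN(
//       HloModuleConfig config,
//       HloModule::CreateModuleConfigFromShape(program_shape, debug_options,
//                                              &execution_options));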

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
                                 const absl::flat_hash_set<HloInstruction*>&
                                     instructions_in_subcomputation) {
  return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
    return !instructions_in_subcomputation.contains(user);
  });
}
}  // anonymous namespace

HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64_t parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64_t operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline must have exactly one output:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}
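
// Illustrative use of outlining (hedged; `mul` and `add` stand for two
// adjacent instructions of `comp`, given in topological order):
//
//   HloInstruction* call = module->OutlineExpressionFromComputation(
//       {mul, add}, "outlined_mul_add", comp);
//
// Afterwards `mul` and `add` live in a new embedded computation and `comp`
// contains a single kCall instruction in their place.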

int64 HloModule::instruction_count() const {
  int64_t n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
    const absl::flat_hash_set<HloComputation*>& allow_list) const {
  std::vector<HloComputation*> filtered_post_order(allow_list.size());
  auto post_order = this->MakeComputationPostOrder();

  int filtered_idx = 0;
  for (auto& computation : post_order) {
    if (allow_list.contains(computation)) {
      filtered_post_order[filtered_idx] = computation;
      filtered_idx += 1;
    }
  }

  return filtered_post_order;
}

std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  absl::flat_hash_set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  absl::flat_hash_set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (!nonroot_computations.contains(computation.get())) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (!added_computations.contains(embedded_computation)) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK(!added_computations.contains(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  if (post_order.size() != computations_.size()) {
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
               << " computation_count=" << computations_.size();
  }
  return post_order;
}
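
// Note on the guarantee (added commentary): the returned order places every
// callee before its callers, so a pass that must visit computations
// bottom-up can simply iterate the result in order, e.g.:
//
//   for (HloComputation* computation : module->MakeComputationPostOrder()) {
//     ProcessComputation(computation);  // hypothetical visitor
//   }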

namespace {
bool CompareComputationsByContent(const HloComputation* a,
                                  const HloComputation* b) {
  if (a->instruction_count() != b->instruction_count()) {
    return a->instruction_count() < b->instruction_count();
  }
  return a->ToString(HloPrintOptions::Fingerprint()) <
         b->ToString(HloPrintOptions::Fingerprint());
}

void SortComputationsByContent(std::vector<HloComputation*>* computations) {
  absl::c_sort(*computations, CompareComputationsByContent);
}

}  // anonymous namespace

std::vector<HloComputation*> HloModule::MakeComputationSorted() const {
  std::vector<HloComputation*> result = MakeComputationPostOrder();
  if (config().content_aware_computation_sorting()) {
    SortComputationsByContent(&result);
  }
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result = MakeComputationPostOrder();
  result.erase(std::remove_if(
                   result.begin(), result.end(),
                   [](HloComputation* c) { return c->IsFusionComputation(); }),
               result.end());
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted()
    const {
  auto result = MakeNonfusionComputations();
  if (config().content_aware_computation_sorting()) {
    SortComputationsByContent(&result);
  }
  return result;
}

std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  return Clone(config(), suffix);
}

std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
                                            const string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config);

  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));
  module->input_output_alias_config() = input_output_alias_config();
  module->set_is_dynamic(is_dynamic());
  if (has_schedule() && schedule().Verify().ok()) {
    HloSchedule clone_schedule(module.get());
    for (HloComputation* computation : computations()) {
      if (schedule().is_computation_scheduled(computation)) {
        HloComputation* new_computation = context.FindComputation(computation);
        // The module being cloned may have computations that are dead, i.e.,
        // unreachable from the entry computation. In that case, new_computation
        // is nullptr.
        if (new_computation != nullptr) {
          HloInstructionSequence& clone_sequence =
              clone_schedule.GetOrCreateSequence(new_computation);
          for (const HloInstruction* instruction :
               schedule().sequence(computation).instructions()) {
            clone_sequence.push_back(context.GetInstruction(instruction));
          }
        }
      }
    }
    TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
  }
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    module->AddCrossProgramPrefetch(parameter, indices);
  }
  return module;
}
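
// Hedged example: cloning a module before running a destructive pass so the
// original remains available for comparison; `SomePass` is hypothetical.
//
//   std::unique_ptr<HloModule> backup = module->Clone(/*suffix=*/"clone");
//   TF_RETURN_IF_ERROR(SomePass().Run(module.get()).status());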

Status HloModule::RemoveUnusedComputations() {
  std::string suffix = "tmp";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config());
  HloCloneContext context(module.get(), suffix);
  entry_computation_->Clone(suffix, &context);
  std::vector<HloComputation*> to_remove;
  for (auto computation : computations()) {
    auto found_computation = context.FindComputation(computation);
    if (found_computation == nullptr) {
      to_remove.push_back(computation);
    }
  }
  for (auto computation : to_remove) {
    TF_RETURN_IF_ERROR(RemoveEmbeddedComputation(computation));
  }
  return Status::OK();
}

HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
                                                HloCloneContext* context) {
  HloComputation* new_computation;
  if (context != nullptr) {
    if ((new_computation = context->FindComputation(computation)) != nullptr) {
      return new_computation;
    }
    new_computation =
        AddEmbeddedComputation(computation->Clone(context->suffix(), context));
  } else {
    new_computation = AddEmbeddedComputation(computation->Clone(""));
  }
  return new_computation;
}

uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
  auto computations_in_module = computations();
  auto it = absl::c_find_if(
      computations_in_module,
      [&](HloComputation* computation) { return computation->name() == name; });
  return it == computations_in_module.end() ? nullptr : *it;
}

uint64 HloModule::Hash() const {
  uint64 result = entry_computation_layout().Hash();
  // Use MakeComputationSorted() instead of MakeComputationPostOrder()
  // because naming may affect the order of MakeComputationPostOrder() but not
  // MakeComputationSorted().
  for (auto* computation : MakeComputationSorted()) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      result = tensorflow::Hash64Combine(result, instruction->Hash());
    }
  }
  return result;
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla