1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17
18 #include <iterator>
19 #include <set>
20 #include <sstream>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "tensorflow/compiler/xla/map_util.h"
31 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/core/lib/gtl/map_util.h"
35 #include "tensorflow/core/platform/types.h"
36
37 namespace xla {
38
HloModule(const string & name,const HloModuleConfig & config)39 HloModule::HloModule(const string& name, const HloModuleConfig& config)
40 : name_(NameUniquer::GetSanitizedName(name)),
41 config_(config),
42 unique_id_(next_unique_module_id_++) {}
43
set_schedule(HloSchedule schedule)44 Status HloModule::set_schedule(HloSchedule schedule) {
45 TF_RET_CHECK(schedule.module() == this);
46 TF_RETURN_IF_ERROR(schedule.Verify());
47 schedule_ = std::move(schedule);
48 return Status::OK();
49 }
50
AddComputationInternal(std::unique_ptr<HloComputation> computation,bool is_entry,bool uniquify_identifiers)51 HloComputation* HloModule::AddComputationInternal(
52 std::unique_ptr<HloComputation> computation, bool is_entry,
53 bool uniquify_identifiers) {
54 if (is_entry) {
55 CHECK_EQ(nullptr, entry_computation_);
56 entry_computation_ = computation.get();
57
58 // If the module configuration has no entry layout computation set, create a
59 // default one based on the program shape.
60 if (!config_.has_entry_computation_layout()) {
61 config_.SetDefaultComputationLayout(
62 entry_computation_->ComputeProgramShape());
63 }
64 input_output_alias_config_ = HloInputOutputAliasConfig(
65 entry_computation_->root_instruction()->shape());
66 }
67
68 if (uniquify_identifiers) {
69 computation->UniquifyName(&computation_name_uniquer_);
70 for (auto* instruction : computation->instructions()) {
71 instruction->UniquifyName(&instruction_name_uniquer_);
72 }
73
74 // Pick unique IDs for each instruction.
75 for (auto* instruction : computation->instructions()) {
76 instruction->SetUniqueId(NewUniqueInstructionId());
77 }
78 // Set unique id to this computation.
79 CHECK_NE(computation->root_instruction()->unique_id(), -1)
80 << "Root has no valid id: " << computation->ToString();
81 computation->SetUniqueId(computation->root_instruction()->unique_id());
82 } else {
83 // Don't uniquify the names of the computation or instruction, but we must
84 // run the names through the uniquifiers to prevent future name collisions
85 // for computations and instructions created later. Also, set the
86 // next_unique_id_ to the one greater than the max unique id of any
87 // instruction (or the computation) to avoid ID collisions.
88 computation_name_uniquer_.GetUniqueName(computation->name());
89 for (auto* instruction : computation->instructions()) {
90 instruction_name_uniquer_.GetUniqueName(instruction->name());
91 next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
92 }
93 if (next_unique_id_ < computation->unique_id() + 1) {
94 next_unique_id_ = computation->unique_id() + 1;
95 }
96 }
97
98 computation->set_parent(this);
99 computations_.push_back(std::move(computation));
100 return computations_.back().get();
101 }
102
AddEntryComputation(std::unique_ptr<HloComputation> computation)103 HloComputation* HloModule::AddEntryComputation(
104 std::unique_ptr<HloComputation> computation) {
105 return AddComputationInternal(std::move(computation), /*is_entry=*/true,
106 /*uniquify_identifiers=*/true);
107 }
108
RemoveEmbeddedComputation(HloComputation * to_remove)109 Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
110 auto it = absl::c_find_if(
111 computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
112 return comp.get() == to_remove;
113 });
114 TF_RET_CHECK(it->get() == to_remove);
115 computations_.erase(it);
116 return Status::OK();
117 }
118
AddEmbeddedComputation(std::unique_ptr<HloComputation> computation)119 HloComputation* HloModule::AddEmbeddedComputation(
120 std::unique_ptr<HloComputation> computation) {
121 return AddComputationInternal(std::move(computation), /*is_entry=*/false,
122 /*uniquify_identifiers=*/true);
123 }
124
ReplaceComputations(const std::unordered_map<HloComputation *,HloComputation * > & replacements)125 void HloModule::ReplaceComputations(
126 const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
127 // Replace all uses of non-canonical computations with their
128 // representatives.
129 std::vector<std::unique_ptr<HloComputation>> new_computations;
130 new_computations.reserve(computations_.size());
131
132 for (std::unique_ptr<HloComputation>& computation : computations_) {
133 for (auto* instruction : computation->instructions()) {
134 switch (instruction->opcode()) {
135 case HloOpcode::kCall:
136 case HloOpcode::kMap:
137 case HloOpcode::kReduce:
138 case HloOpcode::kReduceWindow:
139 case HloOpcode::kScatter: {
140 HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
141 replacements, instruction->to_apply(), nullptr);
142 if (new_arg != nullptr) {
143 instruction->set_to_apply(new_arg);
144 }
145 break;
146 }
147 case HloOpcode::kWhile: {
148 HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
149 replacements, instruction->while_condition(), nullptr);
150 if (new_condition != nullptr) {
151 instruction->set_while_condition(new_condition);
152 }
153 HloComputation* new_body = tensorflow::gtl::FindWithDefault(
154 replacements, instruction->while_body(), nullptr);
155 if (new_body != nullptr) {
156 instruction->set_while_body(new_body);
157 }
158 break;
159 }
160 case HloOpcode::kConditional: {
161 for (int b = 0; b < instruction->branch_count(); ++b) {
162 HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
163 replacements, instruction->branch_computation(b), nullptr);
164 if (new_computation != nullptr) {
165 instruction->set_branch_computation(b, new_computation);
166 }
167 }
168 break;
169 }
170 case HloOpcode::kSelectAndScatter: {
171 HloComputation* new_select = tensorflow::gtl::FindWithDefault(
172 replacements, instruction->select(), nullptr);
173 if (new_select != nullptr) {
174 instruction->set_select(new_select);
175 }
176 HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
177 replacements, instruction->scatter(), nullptr);
178 if (new_scatter != nullptr) {
179 instruction->set_scatter(new_scatter);
180 }
181 break;
182 }
183 default:
184 break;
185 }
186 }
187
188 if (replacements.find(computation.get()) == replacements.end()) {
189 new_computations.push_back(std::move(computation));
190 }
191 }
192
193 // Replace entry_computation if necessary.
194 entry_computation_ = tensorflow::gtl::FindWithDefault(
195 replacements, entry_computation_, entry_computation_);
196
197 computations_ = std::move(new_computations);
198 }
199
ToString(const HloPrintOptions & options) const200 string HloModule::ToString(const HloPrintOptions& options) const {
201 std::ostringstream s;
202 s << "HloModule " << name();
203 if (has_schedule()) {
204 TF_CHECK_OK(schedule().Verify());
205 s << ", is_scheduled=true";
206 }
207 s << "\n\n";
208 for (const HloComputation* computation : MakeComputationPostOrder()) {
209 if (computation == entry_computation()) {
210 s << "ENTRY ";
211 }
212 if (has_schedule() && schedule().is_computation_scheduled(computation)) {
213 s << computation->ToString(
214 options, schedule().sequence(computation).instructions())
215 << "\n\n";
216 } else {
217 s << computation->ToString(options) << "\n\n";
218 }
219 }
220 return s.str();
221 }
222
ToProto() const223 HloModuleProto HloModule::ToProto() const {
224 HloModuleProto proto;
225 proto.set_id(unique_id_);
226 proto.set_name(name_);
227 proto.set_entry_computation_name(entry_computation_->name());
228 proto.set_entry_computation_id(entry_computation_->unique_id());
229 for (const HloComputation* computation : MakeComputationPostOrder()) {
230 HloComputationProto computation_proto = computation->ToProto();
231 proto.add_computations()->Swap(&computation_proto);
232 }
233 if (has_schedule()) {
234 *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
235 }
236 *proto.mutable_host_program_shape() =
237 entry_computation_layout().ComputeProgramShape().ToProto();
238 *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
239 *proto.mutable_dynamic_parameter_binding() =
240 dynamic_parameter_binding().ToProto();
241 return proto;
242 }
243
CheckUniqueNamesAndIdsForComputationsAndInstructions() const244 Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
245 absl::flat_hash_set<string> computation_names;
246 absl::flat_hash_set<int> computation_ids;
247 absl::flat_hash_set<string> instruction_names;
248 absl::flat_hash_set<int> instruction_ids;
249
250 for (const HloComputation* computation : computations()) {
251 TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
252 << "Computation name is not unique: " << computation->name();
253 computation_names.insert(computation->name());
254
255 TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
256 << "Computation id is not unique: " << computation->unique_id();
257 computation_ids.insert(computation->unique_id());
258
259 for (const HloInstruction* instruction : computation->instructions()) {
260 TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
261 << "Instruction name is not unique: " << instruction->name();
262 instruction_names.insert(instruction->name());
263
264 TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
265 << "Instruction id is not unique: " << instruction->unique_id();
266 instruction_ids.insert(instruction->unique_id());
267 }
268 }
269 return Status::OK();
270 }
271
272 /* static */
CreateFromProto(const HloModuleProto & proto,const HloModuleConfig & module_config)273 StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
274 const HloModuleProto& proto, const HloModuleConfig& module_config) {
275 VLOG(2) << "CreateFromProto()";
276 XLA_VLOG_LINES(3, proto.DebugString());
277
278 // The ProgramShape in the passed in module config must match the shapes of
279 // the entry parameters and root.
280 TF_RET_CHECK(proto.has_host_program_shape())
281 << "No program shape found in the proto";
282 ProgramShape expected_program_shape(proto.host_program_shape());
283 TF_RET_CHECK(expected_program_shape.parameters_size() ==
284 module_config.entry_computation_layout().parameter_count());
285 for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
286 const Shape& parameter_shape =
287 module_config.entry_computation_layout().parameter_layout(i).shape();
288 TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
289 parameter_shape))
290 << "HloModuleConfig has different shape for parameter " << i
291 << " than the HLO module. Expected: "
292 << ShapeUtil::HumanStringWithLayout(
293 expected_program_shape.parameters(i))
294 << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
295 }
296 const Shape& result_shape =
297 module_config.entry_computation_layout().result_layout().shape();
298 TF_RET_CHECK(
299 ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
300 << "HloModuleConfig has different result shape than the HLO module. "
301 "Expected: "
302 << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
303 << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
304
305 absl::flat_hash_map<int64, HloComputation*> computation_map;
306 absl::flat_hash_map<HloComputation*, int64> to_proto_id;
307 std::vector<std::unique_ptr<HloComputation>> computations;
308 HloComputation* entry = nullptr;
309 for (const HloComputationProto& computation_proto : proto.computations()) {
310 TF_ASSIGN_OR_RETURN(
311 std::unique_ptr<HloComputation> computation,
312 HloComputation::CreateFromProto(computation_proto, computation_map));
313 CHECK_NE(computation.get(), nullptr);
314 int64 computation_id = computation_proto.id();
315 TF_RET_CHECK(computation_id != -1);
316 TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
317 computation_map[computation_id] = computation.get();
318 to_proto_id[computation.get()] = computation_id;
319 if (computation_id == proto.entry_computation_id()) {
320 entry = computation.get();
321 }
322 computations.push_back(std::move(computation));
323 }
324 TF_RET_CHECK(entry != nullptr);
325
326 auto module = absl::make_unique<HloModule>(proto.name(), module_config);
327
328 // Sort the computations in the proto id's order.
329 absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
330 const std::unique_ptr<HloComputation>& b) {
331 return to_proto_id[a.get()] < to_proto_id[b.get()];
332 });
333
334 // Add sorted computations to the module.
335 for (auto& computation : computations) {
336 bool is_entry = computation.get() == entry;
337 // Don't uniquify names because we want names to be stable across
338 // serialization and deserialization.
339 module->AddComputationInternal(std::move(computation), is_entry,
340 /*uniquify_identifiers=*/false);
341 }
342 TF_RET_CHECK(module->entry_computation_ != nullptr);
343
344 TF_ASSIGN_OR_RETURN(
345 module->input_output_alias_config_,
346 HloInputOutputAliasConfig::CreateFromProto(
347 entry->ComputeProgramShape().result(), proto.input_output_alias()));
348
349 // Because we didn't uniquify the names or the ids, double-check that the
350 // instruction and computation names and ids are unique from the proto.
351 TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
352 DynamicParameterBinding::CreateFromProto(
353 proto.dynamic_parameter_binding()));
354
355 TF_RETURN_IF_ERROR(
356 module->CheckUniqueNamesAndIdsForComputationsAndInstructions());
357
358 if (proto.has_schedule()) {
359 TF_ASSIGN_OR_RETURN(
360 HloSchedule schedule,
361 HloSchedule::CreateFromProto(module.get(), proto.schedule()));
362 TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
363 }
364
365 return std::move(module);
366 }
367
368 /* static */
CreateModuleConfigFromProto(const HloModuleProto & module,const DebugOptions & debug_options)369 StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
370 const HloModuleProto& module, const DebugOptions& debug_options) {
371 TF_RET_CHECK(module.has_host_program_shape())
372 << "No program shape found in the proto";
373 ProgramShape program_shape(module.host_program_shape());
374
375 HloModuleConfig module_config(ProgramShape{program_shape});
376 module_config.set_debug_options(debug_options);
377
378 // The module config is constructed with default layouts regardless of what is
379 // passed in via the ProgramShape. Set the layouts to the appropriate values.
380 ComputationLayout* entry_layout =
381 module_config.mutable_entry_computation_layout();
382 for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
383 TF_RETURN_IF_ERROR(
384 entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
385 program_shape.parameters(i)));
386 }
387 TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
388 program_shape.result()));
389 return module_config;
390 }
391
392 namespace {
393 // Returns whether `hlo` is used outside the given subcomputation.
394 // `instructions_in_subcomputation` is the instruction set of the given
395 // subcomputation.
IsUsedOutsideSubcomputation(const HloInstruction & hlo,const absl::flat_hash_set<HloInstruction * > & instructions_in_subcomputation)396 bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
397 const absl::flat_hash_set<HloInstruction*>&
398 instructions_in_subcomputation) {
399 return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
400 return !instructions_in_subcomputation.contains(user);
401 });
402 }
403 } // anonymous namespace
404
OutlineExpressionFromComputation(absl::Span<HloInstruction * const> instructions_to_outline,const string & outlined_computation_name,HloComputation * computation)405 HloInstruction* HloModule::OutlineExpressionFromComputation(
406 absl::Span<HloInstruction* const> instructions_to_outline,
407 const string& outlined_computation_name, HloComputation* computation) {
408 auto builder = HloComputation::Builder(outlined_computation_name);
409
410 // A map from original instructions to their counterparts in the new outlined
411 // function.
412 absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
413 // A set that contains all instructions to be outlined.
414 absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
415 instructions_to_outline.begin(), instructions_to_outline.end());
416 std::vector<HloInstruction*> arguments;
417 std::vector<HloInstruction*> outputs;
418 int64 parameter_count = 0;
419 for (HloInstruction* instruction_to_outline : instructions_to_outline) {
420 // Clone the original instruction.
421 HloInstruction* outlined_instruction =
422 builder.AddInstruction(instruction_to_outline->Clone());
423
424 // Replace its operands to their counterparts in the new function.
425 for (int64 operand_num = 0;
426 operand_num < outlined_instruction->operand_count(); ++operand_num) {
427 HloInstruction* old_operand =
428 outlined_instruction->mutable_operand(operand_num);
429
430 HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
431 if (*operand_slot == nullptr) {
432 // Because instructions_to_outline is in topological order, if
433 // old_operand is not in outlined_instructions, old_operand must be an
434 // input of the outlined subcomputation and thus should be represented
435 // as a parameter in the new function.
436 arguments.push_back(old_operand);
437 *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
438 parameter_count, old_operand->shape(), "p"));
439 ++parameter_count;
440 }
441 TF_CHECK_OK(
442 outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
443 }
444
445 // Insert the new instruction into the outlined_instructions map.
446 InsertOrDie(&outlined_instructions, instruction_to_outline,
447 outlined_instruction);
448
449 // Mark instruction_to_outline an output if it is used outside the
450 // subcomputation or is the output of the original computation (i.e. used
451 // externally).
452 if (instruction_to_outline->user_count() == 0 ||
453 IsUsedOutsideSubcomputation(*instruction_to_outline,
454 instruction_set_to_outline)) {
455 outputs.push_back(instruction_to_outline);
456 }
457 }
458
459 if (outputs.size() != 1) {
460 string error_message =
461 "The subcomputation to outline has multiple outputs:\n";
462 for (HloInstruction* output : outputs) {
463 absl::StrAppend(&error_message, output->ToString(), "\n");
464 }
465 LOG(FATAL) << error_message;
466 }
467 HloInstruction* output = outputs[0];
468
469 // Creates a call to the nested computation.
470 HloComputation* nested_computation = AddEmbeddedComputation(
471 builder.Build(FindOrDie(outlined_instructions, output)));
472 HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
473 output->shape(), arguments, nested_computation));
474
475 VLOG(2) << "Outlining the following instructions";
476 for (auto* instruction_to_outline : instructions_to_outline) {
477 VLOG(2) << " " << instruction_to_outline->ToString();
478 }
479 VLOG(2) << "as a call " << call->ToString();
480 VLOG(2) << "to " << nested_computation->ToString();
481
482 TF_CHECK_OK(output->ReplaceAllUsesWith(call));
483 for (auto i = instructions_to_outline.rbegin();
484 i != instructions_to_outline.rend(); ++i) {
485 TF_CHECK_OK(computation->RemoveInstruction(*i));
486 }
487
488 return call;
489 }
490
instruction_count() const491 int64 HloModule::instruction_count() const {
492 int64 n = 0;
493 for (const auto& computation : computations_) {
494 n += computation->instruction_count();
495 }
496 return n;
497 }
498
MakeComputationPostOrder() const499 std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
500 // First determine all root computations by building a set of nonroot
501 // computations (computations which are called by an instruction in the
502 // module).
503 absl::flat_hash_set<HloComputation*> nonroot_computations;
504 for (auto& computation : computations_) {
505 for (auto* instruction : computation->instructions()) {
506 for (HloComputation* called_computation :
507 instruction->called_computations()) {
508 nonroot_computations.insert(called_computation);
509 }
510 }
511 }
512
513 // Keep track of computations which have already been added to the post
514 // order. This prevents duplication as an embedded computation may be called
515 // from two different root computations.
516 absl::flat_hash_set<HloComputation*> added_computations;
517 std::vector<HloComputation*> post_order;
518 for (auto& computation : computations_) {
519 if (!nonroot_computations.contains(computation.get())) {
520 for (HloComputation* embedded_computation :
521 computation->MakeEmbeddedComputationsList()) {
522 if (!added_computations.contains(embedded_computation)) {
523 post_order.push_back(embedded_computation);
524 added_computations.insert(embedded_computation);
525 }
526 }
527 // Root computations should only be encountered once.
528 CHECK(!added_computations.contains(computation.get()));
529 post_order.push_back(computation.get());
530 added_computations.insert(computation.get());
531 }
532 }
533 if (post_order.size() != computations_.size()) {
534 for (HloComputation* computation : post_order) {
535 LOG(ERROR) << "Post Order: " << computation->name() << " ("
536 << computation->parent()->name() << ")";
537 }
538 for (auto& computation : computations_) {
539 LOG(ERROR) << "Computations: " << computation->name() << " ("
540 << computation->parent()->name() << ")";
541 }
542 LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
543 << " computation_count=" << computations_.size();
544 }
545 return post_order;
546 }
547
MakeNonfusionComputations() const548 std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
549 std::vector<HloComputation*> result;
550 for (auto* c : computations()) {
551 if (c->IsFusionComputation()) {
552 continue;
553 }
554 result.push_back(c);
555 }
556 return result;
557 }
558
Clone(const string & suffix) const559 std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
560 return Clone(config(), suffix);
561 }
562
Clone(const HloModuleConfig & config,const string & suffix) const563 std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
564 const string& suffix) const {
565 VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
566 auto module = absl::make_unique<HloModule>(
567 absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config);
568
569 HloCloneContext context(module.get(), suffix);
570 auto cloned_computation = entry_computation_->Clone(suffix, &context);
571 module->AddEntryComputation(std::move(cloned_computation));
572
573 if (has_schedule() && schedule().Verify().ok()) {
574 HloSchedule clone_schedule(module.get());
575 for (HloComputation* computation : computations()) {
576 if (schedule().is_computation_scheduled(computation)) {
577 HloInstructionSequence& clone_sequence =
578 clone_schedule.GetOrCreateSequence(
579 context.GetComputation(computation));
580 for (const HloInstruction* instruction :
581 schedule().sequence(computation).instructions()) {
582 clone_sequence.push_back(context.GetInstruction(instruction));
583 }
584 }
585 }
586 TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
587 }
588 return module;
589 }
590
DeepCloneComputation(HloComputation * computation,HloCloneContext * context)591 HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
592 HloCloneContext* context) {
593 HloComputation* new_computation;
594 if (context != nullptr) {
595 if ((new_computation = context->FindComputation(computation)) != nullptr) {
596 return new_computation;
597 }
598 new_computation =
599 AddEmbeddedComputation(computation->Clone(context->suffix(), context));
600 } else {
601 new_computation = AddEmbeddedComputation(computation->Clone(""));
602 }
603 return new_computation;
604 }
605
RandomNew64() const606 uint64 HloModule::RandomNew64() const {
607 tensorflow::mutex_lock l(rng_mutex_);
608 return rng_();
609 }
610
GetComputationWithName(absl::string_view name)611 HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
612 auto computations_in_module = computations();
613 auto it = absl::c_find_if(
614 computations_in_module,
615 [&](HloComputation* computation) { return computation->name() == name; });
616 return it == computations_in_module.end() ? nullptr : *it;
617 }
618
619 /* static */ std::atomic<int> HloModule::next_unique_module_id_(0);
620
621 } // namespace xla
622