/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_

#include <atomic>
#include <functional>
#include <list>
#include <memory>
#include <optional>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/compilation_environments.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using LayoutCanonicalizationCallback =
    std::function<StatusOr<std::pair<std::vector<Shape>, Shape>>(
        const HloModule& module)>;

// Describes a compilation unit at the HLO level.
//
// HloModule is the top-level unit in the HLO IR. It corresponds to a whole
// "program". Running a module, from beginning to end, is the only way to run
// an XLA program.
//
// A module contains one "entry computation"; this HloComputation is like
// main() in a C program. The result of running the module is the result of
// running this computation.
//
// A module also contains some number of "nested computations". Each nested
// computation is attached to an HloInstruction within some other computation.
// The meaning of the nested computation depends on the instruction it's
// attached to.
class HloModule {
 public:
  // Constructor.
  HloModule(const std::string& name, HloModuleConfig config);
  virtual ~HloModule() {}

  // Adds an entry computation to the module. A module can only have one entry
  // computation. Returns a pointer to the newly added computation.
  HloComputation* AddEntryComputation(
      std::unique_ptr<HloComputation> computation);
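  // Illustrative usage sketch (not part of this header's API; the builder
  // calls below are the usual HloComputation/HloInstruction factory methods
  // and are shown here only as an example):
  //
  //   HloModuleConfig config;
  //   auto module = std::make_unique<HloModule>("example_module", config);
  //   HloComputation::Builder builder("entry");
  //   HloInstruction* c = builder.AddInstruction(
  //       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  //   module->AddEntryComputation(builder.Build());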
  // Same as the AddEntryComputation function above but the module's
  // entry_computation_layout is updated to match the layout of the new entry
  // computation.
  HloComputation* AddEntryComputationWithLayouts(
      std::unique_ptr<HloComputation> computation);

  // Replaces the current entry computation with another computation. The new
  // entry computation must be a computation that is already in the module.
  void ReplaceEntryComputation(HloComputation* entry_computation);

  // Adds an embedded computation to the module.
  HloComputation* AddEmbeddedComputation(
      std::unique_ptr<HloComputation> computation);

  // Removes an embedded computation.
  Status RemoveEmbeddedComputation(HloComputation* to_remove);

  // Removes unused computations.
  Status RemoveUnusedComputations();

  // Replaces all uses of computations that are keys of 'replacements' with
  // the corresponding values in 'replacements'. Replaces the entry
  // computation, if applicable.
  //
  // This function iterates over all instructions in the module to find
  // computations to replace. We could speed it up by keeping track of users
  // of computations.
  void ReplaceComputations(
      const absl::flat_hash_map<HloComputation*, HloComputation*>&
          replacements);

  const std::string& name() const { return name_; }
  void set_name(std::string name) { name_ = std::move(name); }

  // Returns a deep copy of this module including all computations.
  std::unique_ptr<HloModule> Clone(const std::string& suffix = "clone") const;
  std::unique_ptr<HloModule> Clone(const HloModuleConfig& config,
                                   const std::string& suffix = "clone") const;

  // Performs a deep clone of the computation by recursively cloning all the
  // called computations as well. If a clone context is specified, it will be
  // populated with the cloned object mappings.
  HloComputation* DeepCloneComputation(HloComputation* computation,
                                       HloCloneContext* context = nullptr);

  // Returns a pointer to the entry computation of the module.
  HloComputation* entry_computation() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation_;
  }

  bool has_entry_computation() const { return entry_computation_ != nullptr; }

  // Returns the root instruction shape of the entry computation.
  //
  // Precondition: entry_computation_ is not nullptr.
  const Shape& result_shape() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation()->root_instruction()->shape();
  }
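  // Sketch of how the cloning helpers above are typically used (`module` and
  // `callee` are illustrative names, not part of this API):
  //
  //   std::unique_ptr<HloModule> copy = module.Clone(/*suffix=*/"copy");
  //   HloCloneContext context(&module);
  //   HloComputation* cloned = module.DeepCloneComputation(callee, &context);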
  // Creates the ComputationLayout which describes the current status of the
  // HLO module entry computation.
  ComputationLayout compute_computation_layout() const {
    return ComputationLayout(entry_computation()->ComputeProgramShape(),
                             /*ignore_layouts=*/false);
  }

  ComputationLayout* mutable_entry_computation_layout() {
    return config_.mutable_entry_computation_layout();
  }

  const ComputationLayout& entry_computation_layout() const {
    return config_.entry_computation_layout();
  }

  void set_use_auto_spmd_partitioning(bool use) {
    use_auto_spmd_partitioning_ = use;
  }

  bool use_auto_spmd_partitioning() const {
    return use_auto_spmd_partitioning_;
  }

  // Based on the sharded shapes of the module's entry computation,
  // layout_canonicalization_callback_ computes and returns
  // <argument_layouts, result_layout> for the entry computation.
  // argument_layouts is a std::vector<Shape> and result_layout is a Shape.
  // layout_canonicalization_callback_ is used only when
  // use_auto_spmd_partitioning_ = true.
  void set_layout_canonicalization_callback(
      LayoutCanonicalizationCallback callback) {
    layout_canonicalization_callback_ = std::move(callback);
  }

  LayoutCanonicalizationCallback layout_canonicalization_callback() const {
    return layout_canonicalization_callback_;
  }

  // Generates a hash value of an HLO module. The hash considers information
  // on opcodes, shapes, operands, and typically the root instruction. This
  // function returns the same hash value for equivalent HLO modules, with
  // respect to the HloInstruction::Identical() method.
  template <typename H>
  friend H AbslHashValue(H h, const HloModule& module) {
    h = H::combine(std::move(h), module.entry_computation_layout());
    // Use MakeComputationSorted() instead of MakeComputationPostOrder()
    // because naming may affect the order of MakeComputationPostOrder() but
    // not MakeComputationSorted().
    auto computations = module.MakeComputationSorted();
    for (auto* computation : computations) {
      h = H::combine(std::move(h), *computation);
    }
    return H::combine(std::move(h), computations.size());
  }

  // Gets the computations in this module.
  //
  // Returns a view of HloComputation*s, so you can iterate over this in the
  // natural way:
  //
  //   for (HloComputation* c : module->computations()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
  computations() const {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::iterator>>
  computations() {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }
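  // Because of the AbslHashValue friend above, a module can be hashed with
  // the standard Abseil machinery, e.g. (sketch):
  //
  //   size_t module_hash = absl::HashOf(module);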
  // Similar to the above, but returns a filtered view of computations for the
  // specified `execution_threads`. An empty `execution_threads` list means
  // all execution threads are included.
  tensorflow::gtl::iterator_range<FilteringUnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::const_iterator,
      std::function<bool(const HloComputation*)>>>
  computations(
      const absl::flat_hash_set<absl::string_view>& execution_threads) const {
    // Pass execution_threads by value to the predicate to ensure it lives
    // beyond this function.
    std::function<bool(const HloComputation*)> pred =
        [execution_threads](const HloComputation* computation) {
          if (execution_threads.empty()) {
            return true;
          }
          return execution_threads.contains(computation->execution_thread());
        };
    return MakeFilteringUnwrappingIteratorRange(computations_.begin(),
                                                computations_.end(), pred);
  }

  // Returns the computation in this module that has the name `name`. Returns
  // null if there is no such computation.
  HloComputation* GetComputationWithName(absl::string_view name);

  // Gets the number of computations in this module.
  int64_t computation_count() const { return computations_.size(); }

  // Returns the mutable computation for the given index.
  HloComputation* mutable_computation(int64_t idx) {
    CHECK(idx >= 0 && idx < computations_.size());
    return computations_[idx].get();
  }

  // Gets the number of instructions in this module.
  int64_t instruction_count() const;

  // Deallocates removed instructions in each computation.
  void Cleanup() {
    for (auto& comp : computations_) {
      comp->Cleanup();
    }
  }

  // Computes and returns a post order of all computations in the module. The
  // sort is defined like so: if computation A has an instruction which calls
  // computation B, then A will appear after B in the sort.
  std::vector<HloComputation*> MakeComputationPostOrder() const {
    return MakeComputationPostOrder({});
  }
  // Similar to the above, but only returns computations on the specified
  // `execution_threads`. An empty `execution_threads` list means all
  // execution threads are included.
  std::vector<HloComputation*> MakeComputationPostOrder(
      const absl::flat_hash_set<absl::string_view>& execution_threads) const;
  // Same as MakeComputationPostOrder() but only returns the computations that
  // are on the specified `execution_threads` and are also found in the passed
  // in `allow_list`. An empty `execution_threads` list means all execution
  // threads are included.
  std::vector<HloComputation*> MakeComputationPostOrder(
      const absl::flat_hash_set<absl::string_view>& execution_threads,
      const absl::flat_hash_set<HloComputation*>& allow_list) const;

  // Same as MakeComputationPostOrder() but sorts the computations by their
  // contents. The order is no longer a post order.
  std::vector<HloComputation*> MakeComputationSorted() const {
    return MakeComputationSorted({});
  }
  // Same as above but only for the specified `execution_threads`. An empty
  // `execution_threads` list means all execution threads are included.
  std::vector<HloComputation*> MakeComputationSorted(
      const absl::flat_hash_set<absl::string_view>& execution_threads) const;
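  // Illustrative pass-style traversal built on the post order above:
  //
  //   for (HloComputation* comp : module.MakeComputationPostOrder()) {
  //     for (HloInstruction* instr : comp->MakeInstructionPostOrder()) {
  //       // Called computations are visited before their callers, and
  //       // operands before their users.
  //     }
  //   }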
  // Gets the computations in this module which aren't for fusion nodes.
  //
  // Postcondition: All computations in the returned list have
  // !IsFusionComputation().
  //
  // Note: Callers can and do rely on the return value here being a *snapshot*
  // of the module's non-fusion computations -- that is, it's OK to add or
  // remove computations from a module while iterating over
  // MakeNonfusionComputations().
  std::vector<HloComputation*> MakeNonfusionComputations() const {
    return MakeNonfusionComputations({});
  }
  // Same as above but only for the specified `execution_threads`. An empty
  // `execution_threads` list means all execution threads are included.
  std::vector<HloComputation*> MakeNonfusionComputations(
      const absl::flat_hash_set<absl::string_view>& execution_threads) const;

  // Same as MakeNonfusionComputations() but sorts computations by content.
  std::vector<HloComputation*> MakeNonfusionComputationsSorted() const {
    return MakeNonfusionComputationsSorted({});
  }
  // Same as above but only for the specified `execution_threads`. An empty
  // `execution_threads` list means all execution threads are included.
  std::vector<HloComputation*> MakeNonfusionComputationsSorted(
      const absl::flat_hash_set<absl::string_view>& execution_threads) const;

  HloModuleConfig& config() { return config_; }
  const HloModuleConfig& config() const { return config_; }
  void set_config(const HloModuleConfig& config) { config_ = config; }

  bool is_dynamic() const { return is_dynamic_; }
  void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }

  // Returns a string representation of the module.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  std::string ToString() const { return ToString(HloPrintOptions()); }
  std::string ToString(const HloPrintOptions& options) const;

  // Returns a Cord representation of the module.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  absl::Cord ToCord() const { return ToCord(HloPrintOptions()); }
  absl::Cord ToCord(const HloPrintOptions& options) const;

  // Converts an HloModule to or from a proto.
  HloModuleProto ToProto() const;
  static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
      const HloModuleProto& proto, const HloModuleConfig& module_config,
      bool prohibit_empty_literal = true);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the HLO module in the given proto.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
      const HloModuleProto& module, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the given program shape.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromShape(
      const ProgramShape& program_shape, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);
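  // Sketch of a proto round trip using the functions above (the
  // GetDebugOptionsFromFlags() helper is assumed to be available from
  // debug_options_flags.h):
  //
  //   HloModuleProto proto = module.ToProto();
  //   TF_ASSIGN_OR_RETURN(HloModuleConfig config,
  //                       HloModule::CreateModuleConfigFromProto(
  //                           proto, GetDebugOptionsFromFlags()));
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> restored,
  //                       HloModule::CreateFromProto(proto, config));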
  // Outlines the given expression from the given computation.
  // instructions_to_outline contains the instructions that form the
  // expression.
  //
  // Precondition: instructions in instructions_to_outline are in topological
  // order (root of outlined instructions last). TODO(jingyue): take a set of
  // instructions and topologically sort them.
  HloInstruction* OutlineExpressionFromComputation(
      absl::Span<HloInstruction* const> instructions_to_outline,
      const std::string& outlined_computation_name,
      HloComputation* computation);

  // Returns a randomly generated uint64_t.
  uint64_t RandomNew64() const;

  // Returns the NameUniquer for uniquing instruction names in this module.
  NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }

  // Assigns a new unique dense id for an instruction.
  int NewUniqueInstructionId() {
    int result = next_unique_id_;
    next_unique_id_++;
    return result;
  }

  // input_output_alias_config indicates the list of aliased buffers that are
  // expected from the module.
  HloInputOutputAliasConfig& input_output_alias_config() {
    return input_output_alias_config_;
  }
  const HloInputOutputAliasConfig& input_output_alias_config() const {
    return input_output_alias_config_;
  }

  // DynamicParameterBinding holds the list of bindings that indicate which
  // parameter dimensions are dynamic and which parameters represent their
  // runtime value.
  DynamicParameterBinding& dynamic_parameter_binding() {
    return dynamic_parameter_binding_;
  }
  const DynamicParameterBinding& dynamic_parameter_binding() const {
    return dynamic_parameter_binding_;
  }

  // Returns an id that is unique to this module across all modules created
  // over the lifetime of this process.
  int unique_id() const { return unique_id_; }

  // Sets the schedule of the module to the given schedule.
  Status set_schedule(HloSchedule schedule);

  // Clears the schedule of the module.
  void clear_schedule() { schedule_.reset(); }

  // Returns true if the module has a schedule set.
  bool has_schedule() const { return schedule_.has_value(); }

  // Returns the schedule of the module. CHECK fails if no schedule is set.
  const HloSchedule& schedule() const { return *schedule_; }
  HloSchedule& schedule() { return *schedule_; }

  HloComputation* AddComputationAndUnifyNamesAndIds(
      std::unique_ptr<HloComputation> computation, bool is_entry) {
    computation->ClearUniqueIdInternal();
    for (auto* instruction : computation->instructions()) {
      instruction->ClearUniqueIdInternal();
    }
    return AddComputationInternal(std::move(computation), is_entry,
                                  /*uniquify_identifiers=*/true,
                                  /*preserve_entry_layouts=*/true);
  }

  void SetAndUniquifyInstrName(HloInstruction* instr, absl::string_view name) {
    instr->SetAndSanitizeName(name);
    instr->UniquifyName(&instruction_name_uniquer_);
  }

  Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;
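  // Illustrative use of the scheduling API above (assumes `module` already
  // has its computations in place; sequence population is elided):
  //
  //   HloSchedule schedule(&module);
  //   ...  // Populate per-computation instruction sequences.
  //   TF_RETURN_IF_ERROR(module.set_schedule(std::move(schedule)));
  //   if (module.has_schedule()) {
  //     const HloInstructionSequence& seq =
  //         module.schedule().sequence(module.entry_computation());
  //   }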
  // Checks if this config has a list of entry parameters' HLO shardings for
  // SPMD.
  bool has_spmd_parameters_shardings() const {
    return spmd_parameters_shardings_.has_value();
  }

  // Getter and setter for the list of entry parameters' HLO shardings for
  // SPMD.
  const std::vector<HloSharding>& spmd_parameters_shardings() const {
    CHECK(spmd_parameters_shardings_.has_value());
    return *spmd_parameters_shardings_;
  }
  void set_spmd_parameters_shardings(
      const std::vector<HloSharding>& shardings) {
    spmd_parameters_shardings_ = shardings;
  }

  // Checks if this config has the entry computation output's HLO sharding for
  // SPMD.
  bool has_spmd_output_sharding() const {
    return spmd_output_sharding_.has_value();
  }

  // Getter and setter for the entry computation output's HLO sharding for
  // SPMD.
  const HloSharding& spmd_output_sharding() const {
    CHECK(spmd_output_sharding_.has_value());
    return *spmd_output_sharding_;
  }
  void set_spmd_output_sharding(const HloSharding& sharding) {
    spmd_output_sharding_ = sharding;
  }

  // Adds a program argument to be prefetched across programs.
  void AddCrossProgramPrefetch(int64_t parameter, const ShapeIndex& index) {
    cross_program_prefetches_.emplace_back(parameter, index);
  }

  // Gets the list of program arguments to be prefetched across programs.
  const absl::Span<const std::pair<int64_t, ShapeIndex>>
  CrossProgramPrefetches() const {
    return cross_program_prefetches_;
  }

  const HloModuleMetadata& metadata() const { return metadata_; }
  HloModuleMetadata* metadata() { return &metadata_; }

  // Moves (not copies) metadata from this HloModule to `module`. To be used
  // in cases like HloModuleGroup::ReplaceModule when metadata should be
  // transferred out of a module before it's destroyed.
  void MoveMetadataToModule(HloModule* module) {
    module->metadata_ = std::move(metadata_);
  }

  int64_t profile_version() const { return profile_version_; }

  void set_profile_version(int64_t profile_version) {
    profile_version_ = profile_version;
  }

  void add_profile_info(const HloModuleProto::ProfileInfo& profile_info) {
    profile_info_list_.push_back(profile_info);
  }

  void set_profile_info(
      const std::vector<HloModuleProto::ProfileInfo>& profile_info) {
    profile_info_list_ = profile_info;
  }

  const std::vector<HloModuleProto::ProfileInfo>& profile_info() const {
    return profile_info_list_;
  }

  void set_relative_speedup(double relative_speedup) {
    relative_speedup_ = relative_speedup;
  }
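  // Sketch of the cross-program prefetch and SPMD sharding setters above:
  // mark parameter 0 (the whole shape) for prefetching and record a
  // replicated output sharding (values here are illustrative only):
  //
  //   module.AddCrossProgramPrefetch(/*parameter=*/0, /*index=*/{});
  //   module.set_spmd_output_sharding(HloSharding::Replicate());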
  // Sets the **unoptimized** fingerprint for the module. This fingerprint is
  // computed prior to any optimizations.
  void set_autofdo_fingerprint(absl::string_view fingerprint) {
    autofdo_fingerprint_ = std::string(fingerprint);
  }

  absl::string_view autofdo_fingerprint() const { return autofdo_fingerprint_; }

  CompilationEnvironments& comp_envs() const { return *comp_envs_; }

 private:
  // This constructor is used in Clone() to copy the CompilationEnvironments.
  // comp_envs may be null, in which case a clean one will be created.
  HloModule(const std::string& name, HloModuleConfig config,
            std::unique_ptr<CompilationEnvironments> comp_envs);

  HloComputation* AddComputationInternal(
      std::unique_ptr<HloComputation> computation, bool is_entry,
      bool uniquify_identifiers, bool preserve_entry_layouts);

  std::string name_;
  HloModuleConfig config_;
  HloComputation* entry_computation_ = nullptr;
  std::vector<std::unique_ptr<HloComputation>> computations_;

  // Random number generator engine to use when generating random numbers per
  // HloModule compilation.
  // TODO(b/25995601): Replace with better seed setting or dev/random for
  // where we don't need deterministic execution.
  mutable std::mt19937_64 rng_{42};
  mutable absl::Mutex rng_mutex_;

  // Unique name generator for computation and instruction names, which are
  // unique per module.
  NameUniquer computation_name_uniquer_{/*separator=*/"."};
  NameUniquer instruction_name_uniquer_{/*separator=*/"."};
  int next_unique_id_ = 0;

  // Used to keep track of the next unique module id that should be assigned.
  static std::atomic<int> next_unique_module_id_;
  // A unique id to label modules with.
  int unique_id_;

  // The HloSchedule of the module. The schedule, if it exists, contains a
  // sequential order of instructions for each non-fusion computation in the
  // module.
  std::optional<HloSchedule> schedule_;

  // alias_config indicates the alias information of input/output buffers that
  // are expected from the module.
  HloInputOutputAliasConfig input_output_alias_config_;

  // Bindings for dynamic parameter mapping.
  DynamicParameterBinding dynamic_parameter_binding_;

  // The HLO shardings of the entry computation's parameters for
  // SPMD-partitioned programs.
  std::optional<std::vector<HloSharding>> spmd_parameters_shardings_;

  // The HLO sharding of the entry computation's output (root) for
  // SPMD-partitioned programs.
  std::optional<HloSharding> spmd_output_sharding_;

  // Arguments to be prefetched across programs.
  std::vector<std::pair<int64_t, ShapeIndex>> cross_program_prefetches_;

  // Metadata for this module, such as its canonical id and the HLO passes run.
  HloModuleMetadata metadata_;

  // True if the module contains dynamic computation.
  bool is_dynamic_ = false;

  // Optional compilation profile handle.
  int64_t profile_version_ = 0;

  // An array of ProfileInfo specifying what optimization profiles this module
  // contains, along with the relative speedups.
  std::vector<HloModuleProto::ProfileInfo> profile_info_list_;

  // Relative speedup of best config compared to default config.
  double relative_speedup_;

  // The unoptimized module fingerprint.
  std::string autofdo_fingerprint_;

  bool use_auto_spmd_partitioning_ = false;

  // Layout canonicalization callback, used only when
  // use_auto_spmd_partitioning_ = true.
  LayoutCanonicalizationCallback layout_canonicalization_callback_;

  // Compilation environments (protos that carry command line flags and
  // environment variables).
  std::unique_ptr<CompilationEnvironments> comp_envs_ =
      std::make_unique<CompilationEnvironments>();
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_