1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 18 19 #include <atomic> 20 #include <list> 21 #include <memory> 22 #include <random> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 27 #include "absl/strings/string_view.h" 28 #include "absl/types/optional.h" 29 #include "absl/types/span.h" 30 #include "tensorflow/compiler/xla/iterator_util.h" 31 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" 32 #include "tensorflow/compiler/xla/service/hlo.pb.h" 33 #include "tensorflow/compiler/xla/service/hlo_clone_context.h" 34 #include "tensorflow/compiler/xla/service/hlo_computation.h" 35 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" 36 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 37 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 38 #include "tensorflow/compiler/xla/service/hlo_module_metadata.h" 39 #include "tensorflow/compiler/xla/service/hlo_schedule.h" 40 #include "tensorflow/compiler/xla/service/name_uniquer.h" 41 #include "tensorflow/compiler/xla/types.h" 42 #include "tensorflow/core/lib/gtl/iterator_range.h" 43 #include "tensorflow/core/platform/logging.h" 44 #include "tensorflow/core/platform/mutex.h" 45 46 namespace xla { 47 48 // Describes a compilation unit at the HLO level. 49 // 50 // HloModule is the top-level unit in the HLO IR. It corresponds to a whole 51 // "program". Running a module, from beginning to end, is the only way to run 52 // an XLA program. 53 // 54 // A module contains one "entry computation"; this HloComputation is like main() 55 // in a C program. The result of running the module is the result of running 56 // this computation. 57 // 58 // A module also contains some number of "nested computations". Each nested 59 // computation is attached to an HloInstruction within some other computation. 60 // The meaning of the nested computation depends on the instruction it's 61 // attached to. 62 class HloModule { 63 public: 64 // Constructor without a versioned computation handle. This constructor should 65 // only be used for HloModules used outside of the XLA service (eg 66 // tests). The versioned handle is used by the service in the compilation 67 // cache. A default configuration is created for this module. 68 explicit HloModule(const string& name, HloModuleConfig config); ~HloModule()69 virtual ~HloModule() {} 70 71 // Adds an entry computation to the module. A module can only have one entry 72 // computation. Returns a pointer to the newly added computation. 73 HloComputation* AddEntryComputation( 74 std::unique_ptr<HloComputation> computation); 75 76 // Same as the AddEntryComputation function above but the module's 77 // entry_computation_layout is updated to match the layout of the new entry 78 // computation. 79 HloComputation* AddEntryComputationWithLayouts( 80 std::unique_ptr<HloComputation> computation); 81 82 // Replaces the current entry computation with another computation. 83 // The new entry computation must be a computation that is already in the 84 // module. 85 void ReplaceEntryComputation(HloComputation* entry_computation); 86 87 // Adds an embedded computation to the module. 88 HloComputation* AddEmbeddedComputation( 89 std::unique_ptr<HloComputation> computation); 90 91 // Removes an embedded computation. 92 Status RemoveEmbeddedComputation(HloComputation* to_remove); 93 94 // Removes unused computations. 95 Status RemoveUnusedComputations(); 96 97 // Replaces all uses of computations that are keys of 'replacements' with 98 // the corresponding values in 'replacements'. Replaces the entry computation, 99 // if applicable. 100 // 101 // This function iterates over all instructions in the module to find 102 // computations to replace. We could speed it up by keeping track of users of 103 // computations. 104 void ReplaceComputations( 105 const std::unordered_map<HloComputation*, HloComputation*>& replacements); 106 name()107 const string& name() const { return name_; } set_name(string name)108 void set_name(string name) { name_ = std::move(name); } 109 110 // Returns a deep copy of this module including all computations. 111 std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const; 112 std::unique_ptr<HloModule> Clone(const HloModuleConfig& config, 113 const string& suffix = "clone") const; 114 115 // Performs a deep clone of the computation, by recursively cloning all 116 // the called computations as well. If the clone context is specified, it 117 // will be populated with the cloned object mappings. 118 HloComputation* DeepCloneComputation(HloComputation* computation, 119 HloCloneContext* context = nullptr); 120 121 // Return a pointer to the entry computation of the module. entry_computation()122 HloComputation* entry_computation() const { 123 CHECK_NE(nullptr, entry_computation_); 124 return entry_computation_; 125 } 126 has_entry_computation()127 bool has_entry_computation() const { return entry_computation_ != nullptr; } 128 129 // Returns the root instruction shape of entry computation. 130 // 131 // Precondition: entry_computation_ is not nullptr. result_shape()132 const Shape& result_shape() const { 133 CHECK_NE(nullptr, entry_computation_); 134 return entry_computation()->root_instruction()->shape(); 135 } 136 137 // Creates the ComputationLayout which describes the current status of the HLO 138 // module entry computation. compute_computation_layout()139 ComputationLayout compute_computation_layout() const { 140 return ComputationLayout(entry_computation()->ComputeProgramShape(), 141 /*ignore_layouts=*/false); 142 } 143 mutable_entry_computation_layout()144 ComputationLayout* mutable_entry_computation_layout() { 145 return config_.mutable_entry_computation_layout(); 146 } 147 entry_computation_layout()148 const ComputationLayout& entry_computation_layout() const { 149 return config_.entry_computation_layout(); 150 } 151 152 // Generates a hash value of an HLO module. Hash considers 153 // information on opcode, shape, operands, and typically a root instruction. 154 // This function returns the same hash value for equivalent HLO modules, 155 // with respect to HloInstruction::Identical() method. 156 uint64 Hash() const; 157 158 // Gets the computations in this module. 159 // 160 // Returns a view of HloComputation*s, so you can iterate over this in the 161 // natural way: 162 // 163 // for (HloComputation* c : module->computations()) { ... } 164 // 165 tensorflow::gtl::iterator_range<UnwrappingIterator< 166 std::vector<std::unique_ptr<HloComputation>>::const_iterator>> computations()167 computations() const { 168 return {MakeUnwrappingIterator(computations_.begin()), 169 MakeUnwrappingIterator(computations_.end())}; 170 } 171 tensorflow::gtl::iterator_range<UnwrappingIterator< 172 std::vector<std::unique_ptr<HloComputation>>::iterator>> computations()173 computations() { 174 return {MakeUnwrappingIterator(computations_.begin()), 175 MakeUnwrappingIterator(computations_.end())}; 176 } 177 178 // Returns the computation in this module that has the name `name`. Returns 179 // null if there is no such computation. 180 HloComputation* GetComputationWithName(absl::string_view name); 181 182 // Gets the number of computations in this module. computation_count()183 int64 computation_count() const { return computations_.size(); } 184 185 // Returns the mutable computation for the given index. mutable_computation(int64 idx)186 HloComputation* mutable_computation(int64 idx) { 187 CHECK(idx >= 0 && idx < computations_.size()); 188 return computations_[idx].get(); 189 } 190 191 // Gets the number of instructions in this module. 192 int64 instruction_count() const; 193 194 // Deallocate removed instructions in each computation. Cleanup()195 void Cleanup() { 196 for (auto& comp : computations_) { 197 comp->Cleanup(); 198 } 199 } 200 201 // Compute and return a post order of all computations in the module. The sort 202 // is defined like so: if computation A has an instruction which calls 203 // computation B, then A will appear after B in the sort. 204 std::vector<HloComputation*> MakeComputationPostOrder() const; 205 206 // Same as MakeComputationPostOrder() but only returns the computations 207 // that are also found in the passed in allowList 208 std::vector<HloComputation*> MakeComputationPostOrder( 209 const absl::flat_hash_set<HloComputation*>& allow_list) const; 210 211 // Same as MakeComputationPostOrder() but sorting the computations by their 212 // contents. The order is longer post order. 213 std::vector<HloComputation*> MakeComputationSorted() const; 214 215 // Gets the computations in this module which aren't for fusion nodes. 216 // 217 // Postcondition: All computations in the returned list have 218 // !IsFusionComputation(). 219 // 220 // Note: Callers can and do rely on the return value here being a *snapshot* 221 // of the module's non-fusion computations -- that is, it's OK to add or 222 // remove computations from a module while iterating over 223 // MakeNonfusionComputations(). 224 std::vector<HloComputation*> MakeNonfusionComputations() const; 225 226 // Same as MakeNonfusionComputations() but sorting computations by content. 227 std::vector<HloComputation*> MakeNonfusionComputationsSorted() const; 228 config()229 const HloModuleConfig& config() const { return config_; } set_config(const HloModuleConfig & config)230 void set_config(const HloModuleConfig& config) { config_ = config; } 231 232 // Return a string representation of the module. 233 // 234 // (We express the default options using an overload rather than a default 235 // param because gdb ignores default params, but does resolve overloads.) ToString()236 string ToString() const { return ToString(HloPrintOptions()); } 237 string ToString(const HloPrintOptions& options) const; 238 239 // Convert an HloModule to or from a proto. 240 HloModuleProto ToProto() const; 241 static StatusOr<std::unique_ptr<HloModule>> CreateFromProto( 242 const HloModuleProto& proto, const HloModuleConfig& module_config, 243 bool prohibit_empty_literal = true); 244 245 // Creates and returns an HloModuleConfig with an appropriate program shape 246 // for the HLO module in the given proto. 247 static StatusOr<HloModuleConfig> CreateModuleConfigFromProto( 248 const HloModuleProto& module, const DebugOptions& debug_options, 249 const ExecutionOptions* execution_options = nullptr); 250 251 // Creates and returns an HloModuleConfig with an appropriate program shape 252 // for the HLO module in the given proto. 253 static StatusOr<HloModuleConfig> CreateModuleConfigFromShape( 254 const ProgramShape& program_shape, const DebugOptions& debug_options, 255 const ExecutionOptions* execution_options = nullptr); 256 257 // Outlines the given expression from the given computation. 258 // instructions_to_outline contains the instructions that form the expression. 259 // 260 // Precondition: instructions in instructions_to_outline are in topological 261 // order (root of outlined instructions last). TODO(jingyue): takes a set of 262 // instructions and topologically sorts them. 263 HloInstruction* OutlineExpressionFromComputation( 264 absl::Span<HloInstruction* const> instructions_to_outline, 265 const string& outlined_computation_name, HloComputation* computation); 266 267 // Returns a randomly generated uint64. 268 uint64 RandomNew64() const; 269 270 // Returns the NameUniquer for uniquing instruction names in this module. instruction_name_uniquer()271 NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } 272 273 // Assign a new unique dense id for an instruction NewUniqueInstructionId()274 int NewUniqueInstructionId() { 275 int result = next_unique_id_; 276 next_unique_id_++; 277 return result; 278 } 279 280 // input_output_alias_config indicates the list of aliased buffers that are 281 // expected from the module. input_output_alias_config()282 HloInputOutputAliasConfig& input_output_alias_config() { 283 return input_output_alias_config_; 284 } input_output_alias_config()285 const HloInputOutputAliasConfig& input_output_alias_config() const { 286 return input_output_alias_config_; 287 } 288 289 // DynamicParameterBinding holds the list of bindings that indicates which 290 // parameter dimensions are dynamic and which parameters represent their 291 // runtime value. dynamic_parameter_binding()292 DynamicParameterBinding& dynamic_parameter_binding() { 293 return dynamic_parameter_binding_; 294 } dynamic_parameter_binding()295 const DynamicParameterBinding& dynamic_parameter_binding() const { 296 return dynamic_parameter_binding_; 297 } 298 299 // Returns an id that is unique to this module across all modules created over 300 // the lifetime of this process. unique_id()301 int unique_id() const { return unique_id_; } 302 303 // Sets the schedule of the module to the given schedule. 304 Status set_schedule(HloSchedule schedule); 305 306 // Clears the schedule of the module. clear_schedule()307 void clear_schedule() { schedule_.reset(); } 308 309 // Returns true if the module has a schedule set. has_schedule()310 bool has_schedule() const { return schedule_.has_value(); } 311 312 // Returns the schedule of the module. CHECK fails if no schedule is set. schedule()313 const HloSchedule& schedule() const { return *schedule_; } schedule()314 HloSchedule& schedule() { return *schedule_; } 315 AddComputationAndUnifyNamesAndIds(std::unique_ptr<HloComputation> computation,bool is_entry)316 HloComputation* AddComputationAndUnifyNamesAndIds( 317 std::unique_ptr<HloComputation> computation, bool is_entry) { 318 computation->ClearUniqueIdInternal(); 319 for (auto* instruction : computation->instructions()) { 320 instruction->ClearUniqueIdInternal(); 321 } 322 return AddComputationInternal(std::move(computation), is_entry, 323 /*uniquify_identifiers=*/true, 324 /*preserve_entry_layouts=*/true); 325 } 326 327 Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; 328 329 // Checks if this config has a list of entry parameters' HLO shardings for 330 // SPMD. has_spmd_parameters_shardings()331 bool has_spmd_parameters_shardings() const { 332 return spmd_parameters_shardings_.has_value(); 333 } 334 335 // Getter and setter for the list of entry parameters' HLO shardings for SPMD. spmd_parameters_shardings()336 const std::vector<HloSharding>& spmd_parameters_shardings() const { 337 CHECK(spmd_parameters_shardings_.has_value()); 338 return *spmd_parameters_shardings_; 339 } set_spmd_parameters_shardings(const std::vector<HloSharding> & shardings)340 void set_spmd_parameters_shardings( 341 const std::vector<HloSharding>& shardings) { 342 spmd_parameters_shardings_ = shardings; 343 } 344 345 // Checks if this config has the entry computation output's HLO sharding for 346 // SPMD. has_spmd_output_sharding()347 bool has_spmd_output_sharding() const { 348 return spmd_output_sharding_.has_value(); 349 } 350 351 // Getter and setter for the entry computation output's HLO shardings for 352 // SPMD. spmd_output_sharding()353 const HloSharding& spmd_output_sharding() const { 354 CHECK(spmd_output_sharding_.has_value()); 355 return *spmd_output_sharding_; 356 } set_spmd_output_sharding(const HloSharding & sharding)357 void set_spmd_output_sharding(const HloSharding& sharding) { 358 spmd_output_sharding_ = sharding; 359 } 360 361 // Add a program argument to be prefetched across programs. AddCrossProgramPrefetch(int64 parameter,const ShapeIndex & index)362 void AddCrossProgramPrefetch(int64 parameter, const ShapeIndex& index) { 363 cross_program_prefetches_.emplace_back(parameter, index); 364 } 365 366 // Get the list of program arguments to be prefetch across programs. CrossProgramPrefetches()367 const absl::Span<const std::pair<int64, ShapeIndex>> CrossProgramPrefetches() 368 const { 369 return cross_program_prefetches_; 370 } 371 metadata()372 const HloModuleMetadata& metadata() const { return metadata_; } metadata()373 HloModuleMetadata* metadata() { return &metadata_; } 374 375 // Moves (not copies) metadata from this HloModule to `module`. To be used 376 // in cases like HloModuleGroup::ReplaceModule when metadata should be 377 // transferred out of a module before it's destroyed. MoveMetadataToModule(HloModule * module)378 void MoveMetadataToModule(HloModule* module) { 379 module->metadata_ = std::move(metadata_); 380 } 381 382 private: 383 HloComputation* AddComputationInternal( 384 std::unique_ptr<HloComputation> computation, bool is_entry, 385 bool uniquify_identifiers, bool preserve_entry_layouts); 386 387 string name_; 388 HloModuleConfig config_; 389 HloComputation* entry_computation_ = nullptr; 390 std::vector<std::unique_ptr<HloComputation>> computations_; 391 392 // Random number generator engine to use when generating random numbers per 393 // HloModule compilation. 394 // TODO(b/25995601): Replace with better seed setting or dev/random for 395 // where we don't need deterministic execution. 396 mutable std::mt19937_64 rng_{42}; 397 mutable tensorflow::mutex rng_mutex_; 398 399 // Unique name generator for computation and instruction names, which are 400 // unique per module. 401 NameUniquer computation_name_uniquer_{/*separator=*/"."}; 402 NameUniquer instruction_name_uniquer_{/*separator=*/"."}; 403 int next_unique_id_ = 0; 404 405 // Used to keep track of the next unique module id that should be assigned. 406 static std::atomic<int> next_unique_module_id_; 407 // A unique id to label modules with. 408 int unique_id_; 409 410 // The HloSchedule of the module. The schedule if it exists contains a 411 // sequential order of instructions for each non-fusion computation in the 412 // module. 413 absl::optional<HloSchedule> schedule_; 414 415 // alias_config indicates the alias information of input/output buffers that 416 // are expected from the module. 417 HloInputOutputAliasConfig input_output_alias_config_; 418 419 // Bindings for dynamic parameter mapping. 420 DynamicParameterBinding dynamic_parameter_binding_; 421 422 // The HLO shardings of the entry computation's parameters for 423 // SPMD-partitioned programs. 424 absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_; 425 426 // The HLO sharding of the entry computation's output (root) for 427 // SPMD-partitioned programs. 428 absl::optional<HloSharding> spmd_output_sharding_; 429 430 // Arguments to be prefetched across programs. 431 std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_; 432 433 // Metadata for this module, such as its canonical id and the HLO passes run. 434 HloModuleMetadata metadata_; 435 }; 436 437 } // namespace xla 438 439 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 440