/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_

#include <atomic>
#include <list>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

namespace xla {

// Describes a compilation unit at the HLO level.
//
// HloModule is the top-level unit in the HLO IR. It corresponds to a whole
// "program". Running a module, from beginning to end, is the only way to run
// an XLA program.
//
// A module contains one "entry computation"; this HloComputation is like main()
// in a C program. The result of running the module is the result of running
// this computation.
//
// A module also contains some number of "nested computations". Each nested
// computation is attached to an HloInstruction within some other computation.
// The meaning of the nested computation depends on the instruction it's
// attached to.
class HloModule {
 public:
  // Constructor without a versioned computation handle. This constructor should
  // only be used for HloModules used outside of the XLA service (eg
  // tests). The versioned handle is used by the service in the compilation
  // cache. A default configuration is created for this module.
  explicit HloModule(const string& name, HloModuleConfig config);
  // Virtual destructor: HloModule is used as a polymorphic base (deleting
  // through a base pointer must run derived destructors).
  virtual ~HloModule() {}

  // Adds an entry computation to the module. A module can only have one entry
  // computation. Returns a pointer to the newly added computation.
  HloComputation* AddEntryComputation(
      std::unique_ptr<HloComputation> computation);

  // Same as the AddEntryComputation function above but the module's
  // entry_computation_layout is updated to match the layout of the new entry
  // computation.
  HloComputation* AddEntryComputationWithLayouts(
      std::unique_ptr<HloComputation> computation);

  // Replaces the current entry computation with another computation.
  // The new entry computation must be a computation that is already in the
  // module.
  void ReplaceEntryComputation(HloComputation* entry_computation);

  // Adds an embedded (non-entry) computation to the module.
  HloComputation* AddEmbeddedComputation(
      std::unique_ptr<HloComputation> computation);

  // Removes an embedded computation.
  Status RemoveEmbeddedComputation(HloComputation* to_remove);

  // Removes unused computations.
  Status RemoveUnusedComputations();

  // Replaces all uses of computations that are keys of 'replacements' with
  // the corresponding values in 'replacements'. Replaces the entry computation,
  // if applicable.
  //
  // This function iterates over all instructions in the module to find
  // computations to replace. We could speed it up by keeping track of users of
  // computations.
  void ReplaceComputations(
      const absl::flat_hash_map<HloComputation*, HloComputation*>&
          replacements);

  // Accessors for the module's name. set_name takes its argument by value and
  // moves from it (sink parameter).
  const string& name() const { return name_; }
  void set_name(string name) { name_ = std::move(name); }

  // Returns a deep copy of this module including all computations.
  std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
  std::unique_ptr<HloModule> Clone(const HloModuleConfig& config,
                                   const string& suffix = "clone") const;

  // Performs a deep clone of the computation, by recursively cloning all
  // the called computations as well. If the clone context is specified, it
  // will be populated with the cloned object mappings.
  HloComputation* DeepCloneComputation(HloComputation* computation,
                                       HloCloneContext* context = nullptr);

  // Return a pointer to the entry computation of the module.
  // Precondition (CHECK-enforced): an entry computation has been set.
  HloComputation* entry_computation() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation_;
  }

  bool has_entry_computation() const { return entry_computation_ != nullptr; }

  // Returns the root instruction shape of entry computation.
  //
  // Precondition: entry_computation_ is not nullptr.
  const Shape& result_shape() const {
    CHECK_NE(nullptr, entry_computation_);
    return entry_computation()->root_instruction()->shape();
  }

  // Creates the ComputationLayout which describes the current status of the HLO
  // module entry computation.
  ComputationLayout compute_computation_layout() const {
    return ComputationLayout(entry_computation()->ComputeProgramShape(),
                             /*ignore_layouts=*/false);
  }

  // Entry computation layout, delegated to the module config.
  ComputationLayout* mutable_entry_computation_layout() {
    return config_.mutable_entry_computation_layout();
  }

  const ComputationLayout& entry_computation_layout() const {
    return config_.entry_computation_layout();
  }

  // Generates a hash value of an HLO module. Hash considers
  // information on opcode, shape, operands, and typically a root instruction.
  // This function returns the same hash value for equivalent HLO modules,
  // with respect to HloInstruction::Identical() method.
  uint64 Hash() const;

  // Gets the computations in this module.
  //
  // Returns a view of HloComputation*s, so you can iterate over this in the
  // natural way:
  //
  //   for (HloComputation* c : module->computations()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::const_iterator>>
  computations() const {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::vector<std::unique_ptr<HloComputation>>::iterator>>
  computations() {
    return {MakeUnwrappingIterator(computations_.begin()),
            MakeUnwrappingIterator(computations_.end())};
  }

  // Returns the computation in this module that has the name `name`. Returns
  // null if there is no such computation.
  HloComputation* GetComputationWithName(absl::string_view name);

  // Gets the number of computations in this module.
  int64 computation_count() const { return computations_.size(); }

  // Returns the mutable computation for the given index.
  // Precondition (CHECK-enforced): 0 <= idx < computation_count().
  HloComputation* mutable_computation(int64_t idx) {
    CHECK(idx >= 0 && idx < computations_.size());
    return computations_[idx].get();
  }

  // Gets the number of instructions in this module.
  int64 instruction_count() const;

  // Deallocate removed instructions in each computation.
  void Cleanup() {
    for (auto& comp : computations_) {
      comp->Cleanup();
    }
  }

  // Compute and return a post order of all computations in the module. The sort
  // is defined like so: if computation A has an instruction which calls
  // computation B, then A will appear after B in the sort.
  std::vector<HloComputation*> MakeComputationPostOrder() const;

  // Same as MakeComputationPostOrder() but only returns the computations
  // that are also found in the passed in allowList
  std::vector<HloComputation*> MakeComputationPostOrder(
      const absl::flat_hash_set<HloComputation*>& allow_list) const;

  // Same as MakeComputationPostOrder() but sorting the computations by their
  // contents. The order is longer post order.
  std::vector<HloComputation*> MakeComputationSorted() const;

  // Gets the computations in this module which aren't for fusion nodes.
  //
  // Postcondition: All computations in the returned list have
  // !IsFusionComputation().
  //
  // Note: Callers can and do rely on the return value here being a *snapshot*
  // of the module's non-fusion computations -- that is, it's OK to add or
  // remove computations from a module while iterating over
  // MakeNonfusionComputations().
  std::vector<HloComputation*> MakeNonfusionComputations() const;

  // Same as MakeNonfusionComputations() but sorting computations by content.
  std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;

  // Accessors for the module configuration. set_config copies the argument.
  HloModuleConfig& config() { return config_; }
  const HloModuleConfig& config() const { return config_; }
  void set_config(const HloModuleConfig& config) { config_ = config; }

  // Whether the module contains dynamic computation (see is_dynamic_ below).
  bool is_dynamic() const { return is_dynamic_; }
  void set_is_dynamic(bool is_dynamic) { is_dynamic_ = is_dynamic; }

  // Return a string representation of the module.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Convert an HloModule to or from a proto.
  HloModuleProto ToProto() const;
  static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
      const HloModuleProto& proto, const HloModuleConfig& module_config,
      bool prohibit_empty_literal = true);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the HLO module in the given proto.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
      const HloModuleProto& module, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Creates and returns an HloModuleConfig with an appropriate program shape
  // for the HLO module in the given proto.
  static StatusOr<HloModuleConfig> CreateModuleConfigFromShape(
      const ProgramShape& program_shape, const DebugOptions& debug_options,
      const ExecutionOptions* execution_options = nullptr);

  // Outlines the given expression from the given computation.
  // instructions_to_outline contains the instructions that form the expression.
  //
  // Precondition: instructions in instructions_to_outline are in topological
  // order (root of outlined instructions last). TODO(jingyue): takes a set of
  // instructions and topologically sorts them.
  HloInstruction* OutlineExpressionFromComputation(
      absl::Span<HloInstruction* const> instructions_to_outline,
      const string& outlined_computation_name, HloComputation* computation);

  // Returns a randomly generated uint64.
  uint64 RandomNew64() const;

  // Returns the NameUniquer for uniquing instruction names in this module.
  NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; }

  // Assign a new unique dense id for an instruction
  // NOTE(review): not synchronized; presumably callers serialize id
  // assignment externally -- confirm before using from multiple threads.
  int NewUniqueInstructionId() {
    int result = next_unique_id_;
    next_unique_id_++;
    return result;
  }

  // input_output_alias_config indicates the list of aliased buffers that are
  // expected from the module.
  HloInputOutputAliasConfig& input_output_alias_config() {
    return input_output_alias_config_;
  }
  const HloInputOutputAliasConfig& input_output_alias_config() const {
    return input_output_alias_config_;
  }

  // DynamicParameterBinding holds the list of bindings that indicates which
  // parameter dimensions are dynamic and which parameters represent their
  // runtime value.
  DynamicParameterBinding& dynamic_parameter_binding() {
    return dynamic_parameter_binding_;
  }
  const DynamicParameterBinding& dynamic_parameter_binding() const {
    return dynamic_parameter_binding_;
  }

  // Returns an id that is unique to this module across all modules created over
  // the lifetime of this process.
  int unique_id() const { return unique_id_; }

  // Sets the schedule of the module to the given schedule.
  Status set_schedule(HloSchedule schedule);

  // Clears the schedule of the module.
  void clear_schedule() { schedule_.reset(); }

  // Returns true if the module has a schedule set.
  bool has_schedule() const { return schedule_.has_value(); }

  // Returns the schedule of the module. CHECK fails if no schedule is set.
  const HloSchedule& schedule() const { return *schedule_; }
  HloSchedule& schedule() { return *schedule_; }

  // Adds a computation after wiping its unique ids (and those of all its
  // instructions) so that fresh, module-unique identifiers are assigned by
  // AddComputationInternal. Entry-computation layouts are preserved.
  HloComputation* AddComputationAndUnifyNamesAndIds(
      std::unique_ptr<HloComputation> computation, bool is_entry) {
    computation->ClearUniqueIdInternal();
    for (auto* instruction : computation->instructions()) {
      instruction->ClearUniqueIdInternal();
    }
    return AddComputationInternal(std::move(computation), is_entry,
                                  /*uniquify_identifiers=*/true,
                                  /*preserve_entry_layouts=*/true);
  }

  // Verifies that computation/instruction names and ids are unique within the
  // module.
  Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;

  // Checks if this config has a list of entry parameters' HLO shardings for
  // SPMD.
  bool has_spmd_parameters_shardings() const {
    return spmd_parameters_shardings_.has_value();
  }

  // Getter and setter for the list of entry parameters' HLO shardings for SPMD.
  // Getter CHECK-fails if no shardings have been set.
  const std::vector<HloSharding>& spmd_parameters_shardings() const {
    CHECK(spmd_parameters_shardings_.has_value());
    return *spmd_parameters_shardings_;
  }
  void set_spmd_parameters_shardings(
      const std::vector<HloSharding>& shardings) {
    spmd_parameters_shardings_ = shardings;
  }

  // Checks if this config has the entry computation output's HLO sharding for
  // SPMD.
  bool has_spmd_output_sharding() const {
    return spmd_output_sharding_.has_value();
  }

  // Getter and setter for the entry computation output's HLO shardings for
  // SPMD. Getter CHECK-fails if no sharding has been set.
  const HloSharding& spmd_output_sharding() const {
    CHECK(spmd_output_sharding_.has_value());
    return *spmd_output_sharding_;
  }
  void set_spmd_output_sharding(const HloSharding& sharding) {
    spmd_output_sharding_ = sharding;
  }

  // Add a program argument to be prefetched across programs.
  void AddCrossProgramPrefetch(int64_t parameter, const ShapeIndex& index) {
    cross_program_prefetches_.emplace_back(parameter, index);
  }

  // Get the list of program arguments to be prefetch across programs.
  // Returned span is a non-owning view into this module's storage.
  const absl::Span<const std::pair<int64, ShapeIndex>> CrossProgramPrefetches()
      const {
    return cross_program_prefetches_;
  }

  const HloModuleMetadata& metadata() const { return metadata_; }
  HloModuleMetadata* metadata() { return &metadata_; }

  // Moves (not copies) metadata from this HloModule to `module`. To be used
  // in cases like HloModuleGroup::ReplaceModule when metadata should be
  // transferred out of a module before it's destroyed.
  void MoveMetadataToModule(HloModule* module) {
    module->metadata_ = std::move(metadata_);
  }

 private:
  // Shared implementation behind the public Add*Computation entry points.
  HloComputation* AddComputationInternal(
      std::unique_ptr<HloComputation> computation, bool is_entry,
      bool uniquify_identifiers, bool preserve_entry_layouts);

  string name_;
  HloModuleConfig config_;
  // Non-owning pointer to the entry computation; the computation itself is
  // owned by computations_ below.
  HloComputation* entry_computation_ = nullptr;
  std::vector<std::unique_ptr<HloComputation>> computations_;

  // Random number generator engine to use when generating random numbers per
  // HloModule compilation.
  // TODO(b/25995601): Replace with better seed setting or dev/random for
  // where we don't need deterministic execution.
  mutable std::mt19937_64 rng_{42};
  mutable tensorflow::mutex rng_mutex_;

  // Unique name generator for computation and instruction names, which are
  // unique per module.
  NameUniquer computation_name_uniquer_{/*separator=*/"."};
  NameUniquer instruction_name_uniquer_{/*separator=*/"."};
  int next_unique_id_ = 0;

  // Used to keep track of the next unique module id that should be assigned.
  static std::atomic<int> next_unique_module_id_;
  // A unique id to label modules with.
  int unique_id_;

  // The HloSchedule of the module. The schedule if it exists contains a
  // sequential order of instructions for each non-fusion computation in the
  // module.
  absl::optional<HloSchedule> schedule_;

  // alias_config indicates the alias information of input/output buffers that
  // are expected from the module.
  HloInputOutputAliasConfig input_output_alias_config_;

  // Bindings for dynamic parameter mapping.
  DynamicParameterBinding dynamic_parameter_binding_;

  // The HLO shardings of the entry computation's parameters for
  // SPMD-partitioned programs.
  absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;

  // The HLO sharding of the entry computation's output (root) for
  // SPMD-partitioned programs.
  absl::optional<HloSharding> spmd_output_sharding_;

  // Arguments to be prefetched across programs.
  std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;

  // Metadata for this module, such as its canonical id and the HLO passes run.
  HloModuleMetadata metadata_;

  // True if the module contains dynamic computation.
  bool is_dynamic_ = false;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_