1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 18 19 #include <atomic> 20 #include <list> 21 #include <memory> 22 #include <random> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 27 #include "absl/strings/string_view.h" 28 #include "absl/types/optional.h" 29 #include "absl/types/span.h" 30 #include "tensorflow/compiler/xla/iterator_util.h" 31 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" 32 #include "tensorflow/compiler/xla/service/hlo.pb.h" 33 #include "tensorflow/compiler/xla/service/hlo_clone_context.h" 34 #include "tensorflow/compiler/xla/service/hlo_computation.h" 35 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" 36 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 37 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 38 #include "tensorflow/compiler/xla/service/hlo_schedule.h" 39 #include "tensorflow/compiler/xla/service/name_uniquer.h" 40 #include "tensorflow/compiler/xla/types.h" 41 #include "tensorflow/core/lib/gtl/iterator_range.h" 42 #include "tensorflow/core/platform/logging.h" 43 #include "tensorflow/core/platform/mutex.h" 44 45 namespace xla { 46 47 // Describes a compilation unit at the HLO level. 48 // 49 // HloModule is the top-level unit in the HLO IR. It corresponds to a whole 50 // "program". Running a module, from beginning to end, is the only way to run 51 // an XLA program. 52 // 53 // A module contains one "entry computation"; this HloComputation is like main() 54 // in a C program. The result of running the module is the result of running 55 // this computation. 56 // 57 // A module also contains some number of "nested computations". Each nested 58 // computation is attached to an HloInstruction within some other computation. 59 // The meaning of the nested computation depends on the instruction it's 60 // attached to. 61 class HloModule { 62 public: 63 // Constructor without a versioned computation handle. This constructor should 64 // only be used for HloModules used outside of the XLA service (eg 65 // tests). The versioned handle is used by the service in the compilation 66 // cache. A default configuration is created for this module. 67 explicit HloModule(const string& name, HloModuleConfig config); ~HloModule()68 virtual ~HloModule() {} 69 70 // Adds an entry computation to the module. A module can only have one entry 71 // computation. Returns a pointer to the newly added computation. 72 HloComputation* AddEntryComputation( 73 std::unique_ptr<HloComputation> computation); 74 75 // Replaces the current entry computation with another computation. 76 // The new entry computation must be a computation that is already in the 77 // module. 78 void ReplaceEntryComputation(HloComputation* entry_computation); 79 80 // Adds an embedded computation to the module. 81 HloComputation* AddEmbeddedComputation( 82 std::unique_ptr<HloComputation> computation); 83 84 // Removes an embedded computation. 85 Status RemoveEmbeddedComputation(HloComputation* to_remove); 86 87 // Removes unused computations. 88 Status RemoveUnusedComputations(); 89 90 // Replaces all uses of computations that are keys of 'replacements' with 91 // the corresponding values in 'replacements'. Replaces the entry computation, 92 // if applicable. 93 // 94 // This function iterates over all instructions in the module to find 95 // computations to replace. We could speed it up by keeping track of users of 96 // computations. 97 void ReplaceComputations( 98 const std::unordered_map<HloComputation*, HloComputation*>& replacements); 99 name()100 const string& name() const { return name_; } set_name(string name)101 void set_name(string name) { name_ = std::move(name); } 102 103 // Returns a deep copy of this module including all computations. 104 std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const; 105 std::unique_ptr<HloModule> Clone(const HloModuleConfig& config, 106 const string& suffix = "clone") const; 107 108 // Performs a deep clone of the computation, by recursively cloning all 109 // the called computations as well. If the clone context is specified, it 110 // will be populated with the cloned object mappings. 111 HloComputation* DeepCloneComputation(HloComputation* computation, 112 HloCloneContext* context = nullptr); 113 114 // Return a pointer to the entry computation of the module. entry_computation()115 HloComputation* entry_computation() const { 116 CHECK_NE(nullptr, entry_computation_); 117 return entry_computation_; 118 } 119 has_entry_computation()120 bool has_entry_computation() const { return entry_computation_ != nullptr; } 121 122 // Returns the root instruction shape of entry computation. 123 // 124 // Precondition: entry_computation_ is not nullptr. result_shape()125 const Shape& result_shape() const { 126 CHECK_NE(nullptr, entry_computation_); 127 return entry_computation()->root_instruction()->shape(); 128 } 129 130 // Creates the ComputationLayout which describes the current status of the HLO 131 // module entry computation. compute_computation_layout()132 ComputationLayout compute_computation_layout() const { 133 return ComputationLayout(entry_computation()->ComputeProgramShape(), 134 /*ignore_layouts=*/false); 135 } 136 mutable_entry_computation_layout()137 ComputationLayout* mutable_entry_computation_layout() { 138 return config_.mutable_entry_computation_layout(); 139 } 140 entry_computation_layout()141 const ComputationLayout& entry_computation_layout() const { 142 return config_.entry_computation_layout(); 143 } 144 145 // Generates a hash value of an HLO module. Hash considers 146 // information on opcode, shape, operands, and typically a root instruction. 147 // This function returns the same hash value for equivalent HLO modules, 148 // with respect to HloInstruction::Identical() method. 149 uint64 Hash() const; 150 151 // Gets the computations in this module. 152 // 153 // Returns a view of HloComputation*s, so you can iterate over this in the 154 // natural way: 155 // 156 // for (HloComputation* c : module->computations()) { ... } 157 // 158 tensorflow::gtl::iterator_range<UnwrappingIterator< 159 std::vector<std::unique_ptr<HloComputation>>::const_iterator>> computations()160 computations() const { 161 return {MakeUnwrappingIterator(computations_.begin()), 162 MakeUnwrappingIterator(computations_.end())}; 163 } 164 tensorflow::gtl::iterator_range<UnwrappingIterator< 165 std::vector<std::unique_ptr<HloComputation>>::iterator>> computations()166 computations() { 167 return {MakeUnwrappingIterator(computations_.begin()), 168 MakeUnwrappingIterator(computations_.end())}; 169 } 170 171 // Returns the computation in this module that has the name `name`. Returns 172 // null if there is no such computation. 173 HloComputation* GetComputationWithName(absl::string_view name); 174 175 // Gets the number of computations in this module. computation_count()176 int64 computation_count() const { return computations_.size(); } 177 178 // Returns the mutable computation for the given index. mutable_computation(int64 idx)179 HloComputation* mutable_computation(int64 idx) { 180 CHECK(idx >= 0 && idx < computations_.size()); 181 return computations_[idx].get(); 182 } 183 184 // Gets the number of instructions in this module. 185 int64 instruction_count() const; 186 187 // Compute and return a post order of all computations in the module. The sort 188 // is defined like so: if computation A has an instruction which calls 189 // computation B, then A will appear after B in the sort. 190 std::vector<HloComputation*> MakeComputationPostOrder() const; 191 192 // Gets the computations in this module which aren't for fusion nodes. 193 // 194 // Postcondition: All computations in the returned list have 195 // !IsFusionComputation(). 196 // 197 // Note: Callers can and do rely on the return value here being a *snapshot* 198 // of the module's non-fusion computations -- that is, it's OK to add or 199 // remove computations from a module while iterating over 200 // MakeNonfusionComputations(). 201 std::vector<HloComputation*> MakeNonfusionComputations() const; 202 203 // Same as MakeNonfusionComputations() but sorting the computations by names. 204 std::vector<HloComputation*> MakeNonfusionComputationsSorted() const; 205 config()206 const HloModuleConfig& config() const { return config_; } set_config(const HloModuleConfig & config)207 void set_config(const HloModuleConfig& config) { config_ = config; } 208 209 // Return a string representation of the module. 210 // 211 // (We express the default options using an overload rather than a default 212 // param because gdb ignores default params, but does resolve overloads.) ToString()213 string ToString() const { return ToString(HloPrintOptions()); } 214 string ToString(const HloPrintOptions& options) const; 215 216 // Convert an HloModule to or from a proto. 217 HloModuleProto ToProto() const; 218 static StatusOr<std::unique_ptr<HloModule>> CreateFromProto( 219 const HloModuleProto& proto, const HloModuleConfig& module_config, 220 bool prohibit_empty_literal = true); 221 222 // Creates and returns an HloModuleConfig with an appropriate program shape 223 // for the HLO module in the given proto. 224 static StatusOr<HloModuleConfig> CreateModuleConfigFromProto( 225 const HloModuleProto& module, const DebugOptions& debug_options, 226 const ExecutionOptions* execution_options = nullptr); 227 228 // Creates and returns an HloModuleConfig with an appropriate program shape 229 // for the HLO module in the given proto. 230 static StatusOr<HloModuleConfig> CreateModuleConfigFromShape( 231 const ProgramShape& program_shape, const DebugOptions& debug_options, 232 const ExecutionOptions* execution_options = nullptr); 233 234 // Outlines the given expression from the given computation. 235 // instructions_to_outline contains the instructions that form the expression. 236 // 237 // Precondition: instructions in instructions_to_outline are in topological 238 // order (root of outlined instructions last). TODO(jingyue): takes a set of 239 // instructions and topologically sorts them. 240 HloInstruction* OutlineExpressionFromComputation( 241 absl::Span<HloInstruction* const> instructions_to_outline, 242 const string& outlined_computation_name, HloComputation* computation); 243 244 // Returns a randomly generated uint64. 245 uint64 RandomNew64() const; 246 247 // Returns the NameUniquer for uniquing instruction names in this module. instruction_name_uniquer()248 NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } 249 250 // Assign a new unique dense id for an instruction NewUniqueInstructionId()251 int NewUniqueInstructionId() { 252 int result = next_unique_id_; 253 next_unique_id_++; 254 return result; 255 } 256 257 // input_output_alias_config indicates the list of aliased buffers that are 258 // expected from the module. input_output_alias_config()259 HloInputOutputAliasConfig& input_output_alias_config() { 260 return input_output_alias_config_; 261 } input_output_alias_config()262 const HloInputOutputAliasConfig& input_output_alias_config() const { 263 return input_output_alias_config_; 264 } 265 266 // DynamicParameterBinding holds the list of bindings that indicates which 267 // parameter dimensions are dynamic and which parameters represent their 268 // runtime value. dynamic_parameter_binding()269 DynamicParameterBinding& dynamic_parameter_binding() { 270 return dynamic_parameter_binding_; 271 } dynamic_parameter_binding()272 const DynamicParameterBinding& dynamic_parameter_binding() const { 273 return dynamic_parameter_binding_; 274 } 275 276 // Returns an id that is unique to this module across all modules created over 277 // the lifetime of this process. unique_id()278 int unique_id() const { return unique_id_; } 279 280 // Sets the schedule of the module to the given schedule. 281 Status set_schedule(HloSchedule schedule); 282 283 // Clears the schedule of the module. clear_schedule()284 void clear_schedule() { schedule_.reset(); } 285 286 // Returns true if the module has a schedule set. has_schedule()287 bool has_schedule() const { return schedule_.has_value(); } 288 289 // Returns the schedule of the module. CHECK fails if no schedule is set. schedule()290 const HloSchedule& schedule() const { return *schedule_; } schedule()291 HloSchedule& schedule() { return *schedule_; } 292 AddComputationAndUnifyNamesAndIds(std::unique_ptr<HloComputation> computation,bool is_entry)293 HloComputation* AddComputationAndUnifyNamesAndIds( 294 std::unique_ptr<HloComputation> computation, bool is_entry) { 295 computation->ClearUniqueIdInternal(); 296 for (auto* instruction : computation->instructions()) { 297 instruction->ClearUniqueIdInternal(); 298 } 299 return AddComputationInternal(std::move(computation), is_entry, 300 /*uniquify_identifiers=*/true); 301 } 302 303 Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; 304 305 // Checks if this config has a list of entry parameters' HLO shardings for 306 // SPMD. has_spmd_parameters_shardings()307 bool has_spmd_parameters_shardings() const { 308 return spmd_parameters_shardings_.has_value(); 309 } 310 311 // Getter and setter for the list of entry parameters' HLO shardings for SPMD. spmd_parameters_shardings()312 const std::vector<HloSharding>& spmd_parameters_shardings() const { 313 CHECK(spmd_parameters_shardings_.has_value()); 314 return *spmd_parameters_shardings_; 315 } set_spmd_parameters_shardings(const std::vector<HloSharding> & shardings)316 void set_spmd_parameters_shardings( 317 const std::vector<HloSharding>& shardings) { 318 spmd_parameters_shardings_ = shardings; 319 } 320 321 // Checks if this config has the entry computation output's HLO sharding for 322 // SPMD. has_spmd_output_sharding()323 bool has_spmd_output_sharding() const { 324 return spmd_output_sharding_.has_value(); 325 } 326 327 // Getter and setter for the entry computation output's HLO shardings for 328 // SPMD. spmd_output_sharding()329 const HloSharding& spmd_output_sharding() const { 330 CHECK(spmd_output_sharding_.has_value()); 331 return *spmd_output_sharding_; 332 } set_spmd_output_sharding(const HloSharding & sharding)333 void set_spmd_output_sharding(const HloSharding& sharding) { 334 spmd_output_sharding_ = sharding; 335 } 336 337 private: 338 HloComputation* AddComputationInternal( 339 std::unique_ptr<HloComputation> computation, bool is_entry, 340 bool uniquify_identifiers); 341 342 // Same as MakeComputationPostOrder() but sorting the computations by their 343 // contents. 344 std::vector<HloComputation*> MakeComputationSortedByContent() const; 345 346 string name_; 347 HloModuleConfig config_; 348 HloComputation* entry_computation_ = nullptr; 349 std::vector<std::unique_ptr<HloComputation>> computations_; 350 351 // Random number generator engine to use when generating random numbers per 352 // HloModule compilation. 353 // TODO(b/25995601): Replace with better seed setting or dev/random for 354 // where we don't need deterministic execution. 355 mutable std::mt19937_64 rng_{42}; 356 mutable tensorflow::mutex rng_mutex_; 357 358 // Unique name generator for computation and instruction names, which are 359 // unique per module. 360 NameUniquer computation_name_uniquer_{/*separator=*/"."}; 361 NameUniquer instruction_name_uniquer_{/*separator=*/"."}; 362 int next_unique_id_ = 0; 363 364 // Used to keep track of the next unique module id that should be assigned. 365 static std::atomic<int> next_unique_module_id_; 366 // A unique id to label modules with. 367 int unique_id_; 368 369 // The HloSchedule of the module. The schedule if it exists contains a 370 // sequential order of instructions for each non-fusion computation in the 371 // module. 372 absl::optional<HloSchedule> schedule_; 373 374 // alias_config indicates the alias information of input/output buffers that 375 // are expected from the module. 376 HloInputOutputAliasConfig input_output_alias_config_; 377 378 // Bindings for dynamic parameter mapping. 379 DynamicParameterBinding dynamic_parameter_binding_; 380 381 // The HLO shardings of the entry computation's parameters for 382 // SPMD-partitioned programs. 383 absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_; 384 385 // The HLO sharding of the entry computation's output (root) for 386 // SPMD-partitioned programs. 387 absl::optional<HloSharding> spmd_output_sharding_; 388 }; 389 390 } // namespace xla 391 392 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 393