1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 18 19 #include <atomic> 20 #include <list> 21 #include <memory> 22 #include <random> 23 #include <string> 24 #include <unordered_map> 25 #include <vector> 26 27 #include "absl/strings/string_view.h" 28 #include "absl/types/optional.h" 29 #include "absl/types/span.h" 30 #include "tensorflow/compiler/xla/iterator_util.h" 31 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" 32 #include "tensorflow/compiler/xla/service/hlo.pb.h" 33 #include "tensorflow/compiler/xla/service/hlo_clone_context.h" 34 #include "tensorflow/compiler/xla/service/hlo_computation.h" 35 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" 36 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 37 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 38 #include "tensorflow/compiler/xla/service/hlo_schedule.h" 39 #include "tensorflow/compiler/xla/service/name_uniquer.h" 40 #include "tensorflow/compiler/xla/types.h" 41 #include "tensorflow/core/lib/gtl/iterator_range.h" 42 #include "tensorflow/core/platform/logging.h" 43 #include "tensorflow/core/platform/mutex.h" 44 45 namespace xla { 46 47 // Describes a compilation unit at the HLO level. 48 // 49 // HloModule is the top-level unit in the HLO IR. It corresponds to a whole 50 // "program". Running a module, from beginning to end, is the only way to run 51 // an XLA program. 52 // 53 // A module contains one "entry computation"; this HloComputation is like main() 54 // in a C program. The result of running the module is the result of running 55 // this computation. 56 // 57 // A module also contains some number of "nested computations". Each nested 58 // computation is attached to an HloInstruction within some other computation. 59 // The meaning of the nested computation depends on the instruction it's 60 // attached to. 61 class HloModule { 62 public: 63 // Constructor without a versioned computation handle. This constructor should 64 // only be used for HloModules used outside of the XLA service (eg 65 // tests). The versioned handle is used by the service in the compilation 66 // cache. A default configuration is created for this module. 67 explicit HloModule(const string& name, const HloModuleConfig& config); ~HloModule()68 virtual ~HloModule() {} 69 70 // Adds an entry computation to the module. A module can only have one entry 71 // computation. Returns a pointer to the newly added computation. 72 HloComputation* AddEntryComputation( 73 std::unique_ptr<HloComputation> computation); 74 75 // Adds an embedded computation to the module. 76 HloComputation* AddEmbeddedComputation( 77 std::unique_ptr<HloComputation> computation); 78 79 // Removes an embedded computation. 80 Status RemoveEmbeddedComputation(HloComputation* to_remove); 81 82 // Replaces all uses of computations that are keys of 'replacements' with 83 // the corresponding values in 'replacements'. Replaces the entry computation, 84 // if applicable. 85 // 86 // This function iterates over all instructions in the module to find 87 // computations to replace. We could speed it up by keeping track of users of 88 // computations. 89 void ReplaceComputations( 90 const std::unordered_map<HloComputation*, HloComputation*>& replacements); 91 name()92 const string& name() const { return name_; } set_name(string name)93 void set_name(string name) { name_ = std::move(name); } 94 95 // Returns a deep copy of this module including all computations. 96 std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const; 97 std::unique_ptr<HloModule> Clone(const HloModuleConfig& config, 98 const string& suffix = "clone") const; 99 100 // Performs a deep clone of the computation, by recursively cloning all 101 // the called computations as well. If the clone context is specified, it 102 // will be populated with the cloned object mappings. 103 HloComputation* DeepCloneComputation(HloComputation* computation, 104 HloCloneContext* context = nullptr); 105 106 // Return a pointer to the entry computation of the module. entry_computation()107 HloComputation* entry_computation() const { 108 CHECK_NE(nullptr, entry_computation_); 109 return entry_computation_; 110 } 111 112 // Returns the root instruction shape of entry computation. 113 // 114 // Precondition: entry_computation_ is not nullptr. result_shape()115 const Shape& result_shape() const { 116 CHECK_NE(nullptr, entry_computation_); 117 return entry_computation()->root_instruction()->shape(); 118 } 119 120 // Creates the ComputationLayout which describes the current status of the HLO 121 // module entry computation. compute_computation_layout()122 ComputationLayout compute_computation_layout() const { 123 return ComputationLayout(entry_computation()->ComputeProgramShape(), 124 /*ignore_layouts=*/false); 125 } 126 mutable_entry_computation_layout()127 ComputationLayout* mutable_entry_computation_layout() { 128 return config_.mutable_entry_computation_layout(); 129 } 130 entry_computation_layout()131 const ComputationLayout& entry_computation_layout() const { 132 return config_.entry_computation_layout(); 133 } 134 135 // Generates a hash value of an HLO module. Hash considers 136 // information on opcode, shape, operands, and typically a root instruction. 137 // This function returns the same hash value for equivalent HLO modules, 138 // with respect to HloInstruction::Identical() method. Hash()139 uint64 Hash() const { 140 return entry_computation()->root_instruction()->Hash(); 141 } 142 143 // Gets the computations in this module. 144 // 145 // Returns a view of HloComputation*s, so you can iterate over this in the 146 // natural way: 147 // 148 // for (HloComputation* c : module->computations()) { ... } 149 // 150 tensorflow::gtl::iterator_range<UnwrappingIterator< 151 std::vector<std::unique_ptr<HloComputation>>::const_iterator>> computations()152 computations() const { 153 return {MakeUnwrappingIterator(computations_.begin()), 154 MakeUnwrappingIterator(computations_.end())}; 155 } 156 tensorflow::gtl::iterator_range<UnwrappingIterator< 157 std::vector<std::unique_ptr<HloComputation>>::iterator>> computations()158 computations() { 159 return {MakeUnwrappingIterator(computations_.begin()), 160 MakeUnwrappingIterator(computations_.end())}; 161 } 162 163 // Returns the computation in this module that has the name `name`. Returns 164 // null if there is no such computation. 165 HloComputation* GetComputationWithName(absl::string_view name); 166 167 // Gets the number of computations in this module. computation_count()168 int64 computation_count() const { return computations_.size(); } 169 170 // Returns the mutable computation for the given index. mutable_computation(int64 idx)171 HloComputation* mutable_computation(int64 idx) { 172 CHECK(idx >= 0 && idx < computations_.size()); 173 return computations_[idx].get(); 174 } 175 176 // Gets the number of instructions in this module. 177 int64 instruction_count() const; 178 179 // Compute and return a post order of all computations in the module. The sort 180 // is defined like so: if computation A has an instruction which calls 181 // computation B, then A will appear after B in the sort. 182 std::vector<HloComputation*> MakeComputationPostOrder() const; 183 184 // Gets the computations in this module which aren't for fusion nodes. 185 // 186 // Postcondition: All computations in the returned list have 187 // !IsFusionComputation(). 188 // 189 // Note: Callers can and do rely on the return value here being a *snapshot* 190 // of the module's non-fusion computations -- that is, it's OK to add or 191 // remove computations from a module while iterating over 192 // MakeNonfusionComputations(). 193 std::vector<HloComputation*> MakeNonfusionComputations() const; 194 config()195 const HloModuleConfig& config() const { return config_; } set_config(HloModuleConfig & config)196 void set_config(HloModuleConfig& config) { config_ = config; } 197 198 // Return a string representation of the module. 199 // 200 // (We express the default options using an overload rather than a default 201 // param because gdb ignores default params, but does resolve overloads.) ToString()202 string ToString() const { return ToString(HloPrintOptions()); } 203 string ToString(const HloPrintOptions& options) const; 204 205 // Convert an HloModule to or from a proto. 206 HloModuleProto ToProto() const; 207 static StatusOr<std::unique_ptr<HloModule>> CreateFromProto( 208 const HloModuleProto& proto, const HloModuleConfig& module_config); 209 210 // Creates and returns an HloModuleConfig with an appropriate program shape 211 // for the HLO module in the given proto. 212 static StatusOr<HloModuleConfig> CreateModuleConfigFromProto( 213 const HloModuleProto& module, const DebugOptions& debug_options); 214 215 // Outlines the given expression from the given computation. 216 // instructions_to_outline contains the instructions that form the expression. 217 // 218 // Precondition: instructions in instructions_to_outline are in topological 219 // order (root of outlined instructions last). TODO(jingyue): takes a set of 220 // instructions and topologically sorts them. 221 HloInstruction* OutlineExpressionFromComputation( 222 absl::Span<HloInstruction* const> instructions_to_outline, 223 const string& outlined_computation_name, HloComputation* computation); 224 225 // Returns a randomly generated uint64. 226 uint64 RandomNew64() const; 227 228 // Returns the NameUniquer for uniquing instruction names in this module. instruction_name_uniquer()229 NameUniquer& instruction_name_uniquer() { return instruction_name_uniquer_; } 230 231 // Assign a new unique dense id for an instruction NewUniqueInstructionId()232 int NewUniqueInstructionId() { 233 int result = next_unique_id_; 234 next_unique_id_++; 235 return result; 236 } 237 238 // input_output_alias_config indicates the list of aliased buffers that are 239 // expected from the module. input_output_alias_config()240 HloInputOutputAliasConfig& input_output_alias_config() { 241 return input_output_alias_config_; 242 } input_output_alias_config()243 const HloInputOutputAliasConfig& input_output_alias_config() const { 244 return input_output_alias_config_; 245 } 246 247 // DynamicParameterBinding holds the list of bindings that indicates which 248 // parameter dimensions are dynamic and which parameters represent their 249 // runtime value. dynamic_parameter_binding()250 DynamicParameterBinding& dynamic_parameter_binding() { 251 return dynamic_parameter_binding_; 252 } dynamic_parameter_binding()253 const DynamicParameterBinding& dynamic_parameter_binding() const { 254 return dynamic_parameter_binding_; 255 } 256 257 // Returns an id that is unique to this module across all modules created over 258 // the lifetime of this process. unique_id()259 int unique_id() const { return unique_id_; } 260 261 // Sets the schedule of the module to the given schedule. 262 Status set_schedule(HloSchedule schedule); 263 264 // Clears the schedule of the module. clear_schedule()265 void clear_schedule() { schedule_.reset(); } 266 267 // Returns true if the module has a schedule set. has_schedule()268 bool has_schedule() const { return schedule_.has_value(); } 269 270 // Returns the schedue of the module. CHECK fails if no schedule is set. schedule()271 const HloSchedule& schedule() const { return *schedule_; } schedule()272 HloSchedule& schedule() { return *schedule_; } 273 AddComputationAndUnifyNamesAndIds(std::unique_ptr<HloComputation> computation,bool is_entry)274 HloComputation* AddComputationAndUnifyNamesAndIds( 275 std::unique_ptr<HloComputation> computation, bool is_entry) { 276 computation->ClearUniqueIdInternal(); 277 for (auto* instruction : computation->instructions()) { 278 instruction->ClearUniqueIdInternal(); 279 } 280 return AddComputationInternal(std::move(computation), is_entry, 281 /*uniquify_identifiers=*/true); 282 } 283 284 Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const; 285 286 private: 287 HloComputation* AddComputationInternal( 288 std::unique_ptr<HloComputation> computation, bool is_entry, 289 bool uniquify_identifiers); 290 291 string name_; 292 HloModuleConfig config_; 293 HloComputation* entry_computation_ = nullptr; 294 std::vector<std::unique_ptr<HloComputation>> computations_; 295 296 // Random number generator engine to use when generating random numbers per 297 // HloModule compilation. 298 // TODO(b/25995601): Replace with better seed setting or dev/random for 299 // where we don't need deterministic execution. 300 mutable std::mt19937_64 rng_{42}; 301 mutable tensorflow::mutex rng_mutex_; 302 303 // Unique name generator for computation and instruction names, which are 304 // unique per module. 305 NameUniquer computation_name_uniquer_{/*separator=*/"."}; 306 NameUniquer instruction_name_uniquer_{/*separator=*/"."}; 307 int next_unique_id_ = 0; 308 309 // Used to keep track of the next unique module id that should be assigned. 310 static std::atomic<int> next_unique_module_id_; 311 // A unique id to label modules with. 312 int unique_id_; 313 314 // The HloSchedule of the module. The schedule if it exists contains a 315 // sequential order of instructions for each non-fusion computation in the 316 // module. 317 absl::optional<HloSchedule> schedule_; 318 319 // alias_config indicates the alias information of input/output buffers that 320 // are expected from the module. 321 HloInputOutputAliasConfig input_output_alias_config_; 322 323 // Bindings for dynamic parameter mapping. 324 DynamicParameterBinding dynamic_parameter_binding_; 325 }; 326 327 } // namespace xla 328 329 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ 330