/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_

#include <functional>
#include <list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloModule;

// Describes a computation at the HLO level.
//
// You can think of an HloComputation like a function. It has some inputs
// (parameters) and returns exactly one value (the value of its root node). If
// you want to return multiple values, you can return a tuple.
//
// The instructions inside of a computation do not have an explicit total
// order. Instead, they have a partial order determined by their data and
// control dependencies.
//
// An HloModule contains one "entry computation" -- this is like main() in a C
// program. Every other computation inside of a module is attached to one or
// more HloInstructions, as a "nested computation". For example, the kMap
// instruction has a nested computation and "applies" it to every element of
// its input, elementwise. (That is, the input [x, y, z] is transformed to
// [f(x), f(y), f(z)].)
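//
// As a rough sketch of how a computation is typically assembled (illustrative
// only; the ShapeUtil::MakeShape and HloInstruction::Create* factories are
// assumed from elsewhere in XLA, and the Builder class below is the real
// entry point):
//
//   HloComputation::Builder builder("add_xy");
//   Shape f32 = ShapeUtil::MakeShape(F32, {});
//   HloInstruction* x = builder.AddInstruction(
//       HloInstruction::CreateParameter(0, f32, "x"));
//   HloInstruction* y = builder.AddInstruction(
//       HloInstruction::CreateParameter(1, f32, "y"));
//   builder.AddInstruction(
//       HloInstruction::CreateBinary(f32, HloOpcode::kAdd, x, y));
//   std::unique_ptr<HloComputation> computation = builder.Build();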
class HloComputation {
 public:
  // Builder class for HloComputation.
  class Builder {
   public:
    explicit Builder(const string& name,
                     HloInstruction* fusion_instruction = nullptr)
        : name_(name),
          last_added_instruction_(nullptr),
          fusion_instruction_(fusion_instruction) {}

    // Build and return an HloComputation. The parameter root_instruction
    // specifies the already-added instruction to use as the root. If
    // root_instruction is nullptr then use the last added instruction as the
    // root.
    std::unique_ptr<HloComputation> Build(
        HloInstruction* root_instruction = nullptr);

    HloInstruction* AddInstruction(
        std::unique_ptr<HloInstruction> instruction) {
      instructions_.push_back(std::move(instruction));
      last_added_instruction_ = instructions_.back().get();
      return last_added_instruction_;
    }

    Status ForEachInstruction(
        const std::function<Status(const HloInstruction*)>& func) const {
      for (const auto& instruction : instructions_) {
        TF_RETURN_IF_ERROR(func(instruction.get()));
      }
      return Status::OK();
    }

   private:
    const string name_;
    HloInstruction* last_added_instruction_;
    HloInstruction* fusion_instruction_;
    std::vector<std::unique_ptr<HloInstruction>> instructions_;
  };

  // Add an instruction to the computation. The computation takes ownership of
  // the instruction.
  HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);

  // Remove the param_no'th parameter from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveParameter(int64 param_no);

  // Remove unused parameters from the computation.
  // Note this is only applicable to the computation for the fusion
  // instruction.
  Status RemoveUnusedParameters();

  // Adds a new parameter instruction to a fusion computation.
  //
  // This should be a new parameter. Instruction will be appended to parameters
  // and inserted to the instruction list.
  HloInstruction* AddParameter(std::unique_ptr<HloInstruction> instruction);

  // Adds a new parameter instruction to the entry computation and updates
  // the parent module config to reflect the change.
  //
  // This should be a new parameter. Instruction will be appended to parameters
  // and inserted to the instruction list.
  HloInstruction* AddEntryComputationParameter(
      std::unique_ptr<HloInstruction> instruction);

  // Remove an instruction from the computation. The instruction must have no
  // users. Instruction is deallocated with this call.
  Status RemoveInstruction(HloInstruction* instruction);

  // Remove an instruction (including side-effecting ones) from the computation
  // and also transitively any operand that has no side effect and no users
  // post removing an instruction. The instruction must have no users.
  // Instruction is deallocated with this call.
  Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);

  // Set the root of the computation to the given instruction. The instruction
  // must have already been added to the computation. In addition it must have
  // the same shape as the result of the computation for non fusion
  // computations, except if accept_different_shape is set to true.
  void set_root_instruction(HloInstruction* new_root_instruction,
                            bool accept_different_shape = false);

  // Return the root instruction of the computation. The root instruction is
  // the instruction which produces the output of the computation.
  HloInstruction* root_instruction() const { return root_instruction_; }
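
  // For example, a pass that wants to wrap the current root in a negation
  // might do something like the following (an illustrative sketch; `comp` is
  // a hypothetical HloComputation*):
  //
  //   HloInstruction* root = comp->root_instruction();
  //   HloInstruction* neg = comp->AddInstruction(HloInstruction::CreateUnary(
  //       root->shape(), HloOpcode::kNegate, root));
  //   comp->set_root_instruction(neg);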
  // Returns the number of parameters for this computation.
  int64 num_parameters() const { return param_instructions_.size(); }

  // Returns the parameter instruction for the given parameter number.
  HloInstruction* parameter_instruction(int64 param_no) const {
    CHECK_GE(param_no, 0);
    CHECK_LT(param_no, static_cast<int64>(param_instructions_.size()))
        << "Computation " << name() << " has no parameter number " << param_no;
    return param_instructions_[param_no];
  }

  const std::vector<HloInstruction*>& parameter_instructions() const {
    return param_instructions_;
  }

  const string& name() const { return name_; }

  // Use the given NameUniquer to select a unique name for the computation
  // based on the computation's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Return a string representation of the computation.
  //
  // (We express the default options using an overload rather than a default
  // param because gdb ignores default params, but does resolve overloads.)
  string ToString() const { return ToString(HloPrintOptions()); }
  string ToString(const HloPrintOptions& options) const;

  // Overload which accepts an order to emit the instructions in.
  string ToString(
      const HloPrintOptions& options,
      absl::Span<const HloInstruction* const> instruction_order) const;

  // Returns a serialized representation of this computation.
  HloComputationProto ToProto() const;

  // Creates a computation from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed computation
  //     calls.
  static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
      const HloComputationProto& proto,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map);

  // Gets the instructions in this computation.
  //
  // The returned type is a range of HloInstruction*s, so you can iterate over
  // it using a range-based for loop in the natural way:
  //
  //   for (HloInstruction* instr : computation->instructions()) { ... }
  //
  tensorflow::gtl::iterator_range<UnwrappingIterator<
      std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
  instructions() const {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }
  tensorflow::gtl::iterator_range<
      UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
  instructions() {
    return {MakeUnwrappingIterator(instructions_.begin()),
            MakeUnwrappingIterator(instructions_.end())};
  }

  // Compute and return a post-order of the instructions in the computation.
  // In this order, definitions of values always appear before their uses.
  std::vector<HloInstruction*> MakeInstructionPostOrder() const;

  int64 instruction_count() const { return instruction_iterators_.size(); }
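
  // For instance, a simple forward dataflow pass can rely on
  // MakeInstructionPostOrder() to see operands before their users (an
  // illustrative sketch; ProcessInstruction is a hypothetical handler):
  //
  //   for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
  //     // All operands of `instr` have already been visited at this point.
  //     ProcessInstruction(instr);
  //   }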
  // Creates and returns a list of the embedded computations called by this
  // computation. This includes all embedded computations called directly or
  // transitively. The embedded computations are sorted such that if
  // computation A calls computation B (e.g., via a map instruction) then A
  // will appear after B in the list.
  std::vector<HloComputation*> MakeEmbeddedComputationsList() const;

  // Creates a fusion instruction containing the given instructions.
  // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or
  // fusion into a library call. Instructions must be in reverse topological
  // order (root of the fused expression first). Replaces all uses of the
  // original root instruction with the fusion instruction. The original
  // instructions are removed if they have no uses after fusion (this is
  // necessarily true for at least the root).
  HloInstruction* CreateFusionInstruction(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction::FusionKind fusion_kind);

  // Create a deep copy of the given instruction and return the instruction
  // producing the copied result. All instructions performing the copy are
  // added to the computation. For array-shaped values, this method trivially
  // returns a kCopy instruction. For tuple-shaped instructions, the copy is
  // performed with a series of kGetTupleElement and kTuple instructions. If
  // indices_to_copy is non-null then this ShapeTree indicates which elements
  // (arrays) of the shape to copy. Non-copied elements are passed through
  // transparently. If copies_added is non-null, then the added kCopy
  // instructions will be inserted in the respective index in the given
  // ShapeTree.
  StatusOr<HloInstruction*> DeepCopyInstruction(
      HloInstruction* instruction,
      const ShapeTree<bool>* indices_to_copy = nullptr,
      ShapeTree<HloInstruction*>* copies_added = nullptr);

  // As above, but uses a custom function to copy the leaf nodes, which could
  // create alternative HLOs other than kCopy, or even pass-throughs.
  StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
      HloInstruction* instruction,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Computes and returns the ProgramShape of this computation (shape of
  // parameters and result with layout).
  ProgramShape ComputeProgramShape() const;

  // Return whether `*this` and `other` are functionally equivalent.
  bool operator==(const HloComputation& other) const;

  // Replaces old instruction with newly created instruction. Removes old
  // instruction from computation. Updates uses and root instruction.
  Status ReplaceWithNewInstruction(
      HloInstruction* old_instruction,
      std::unique_ptr<HloInstruction> new_instruction);

  // Replace old instruction with new instruction. Updates uses and root
  // instruction. Removes old instruction from computation. Precondition:
  // old_instruction and new_instruction must have compatible shapes.
  Status ReplaceInstruction(HloInstruction* old_instruction,
                            HloInstruction* new_instruction);

  // Set/get the module containing this computation.
  void set_parent(HloModule* module) { parent_ = module; }
  const HloModule* parent() const { return parent_; }
  HloModule* parent() { return parent_; }
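
  // As an example of the replace APIs above, an algebraic-simplification-style
  // rewrite of `x * 1` into `x` could look roughly like this (an illustrative
  // sketch; `comp` and `mul` are hypothetical):
  //
  //   // Redirect every use of the multiply (and the root, if applicable) to
  //   // its first operand, then remove the multiply from `comp`.
  //   TF_RETURN_IF_ERROR(
  //       comp->ReplaceInstruction(mul, mul->mutable_operand(0)));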
  // Visit every node in the computation in DFS post-order with the given
  // visitor. This is similar to calling HloInstruction::Accept on the root of
  // the computation except this method also visits instructions not reachable
  // via the root. The root instruction of the computation is visited last,
  // and the visitor's FinishVisit method is called once upon completion (with
  // the root instruction as the argument).
  template <typename HloInstructionPtr>
  Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor) const;

  // Same as Accept() above, but the order of operand and control predecessor
  // visitation is determined by the given operand order; if compare(A, B) ==
  // true, A is visited before B.
  Status AcceptWithOperandOrder(
      DfsHloVisitor* visitor,
      const HloInstruction::CompareFunction& operand_order) const;

  // Visit every node in the computation in the given order. 'order' must
  // be a topological sort of all instructions in the computation.
  template <typename HloInstructionPtr>
  Status AcceptOrdered(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                       absl::Span<HloInstruction* const> order) const;

  // Same as Accept() above, but the visitor is given as a function.
  Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
  Status Accept(
      const std::function<Status(const HloInstruction*)>& visitor_func) const;
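
  // For instance, counting the kAdd instructions in a computation with the
  // functional form (an illustrative sketch):
  //
  //   int64 num_adds = 0;
  //   TF_RETURN_IF_ERROR(computation->Accept(
  //       [&num_adds](const HloInstruction* instr) {
  //         if (instr->opcode() == HloOpcode::kAdd) ++num_adds;
  //         return Status::OK();
  //       }));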
  // Returns a deep copy of this computation including all instructions.
  // If the clone context is specified, it will be populated with the cloned
  // object mappings, and its module() will be used to add new computations
  // into.
  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
                                        HloCloneContext* context = nullptr);

  // Like Clone(), but if an instruction is present in replacement_map, we use
  // the map's value to replace that instruction in the cloned computation.
  //
  // If replacements maps a key to nullptr, we remove that instruction from the
  // new computation. If an element of `replacements` references an instruction
  // that's not already in the computation, it's cloned and added to the new
  // computation.
  //
  // 'extra_parameters' allows specifying additional parameters that should be
  // added to the computation.
  //
  // All relevant instructions are cloned, *including* unique_ptr in the
  // `replacements` map.
  std::unique_ptr<HloComputation> CloneWithReplacements(
      absl::flat_hash_map<const HloInstruction*,
                          std::unique_ptr<HloInstruction>>
          replacements,
      absl::Span<const HloInstruction* const> extra_parameters = {},
      HloCloneContext* context = nullptr, const string& suffix = "clone");

  // Convenience overloads for CloneWithReplacements. You want to do
  //
  //   CloneWithReplacements({{a, std::move(b)}, {c, std::move(d)}})  // ERROR
  //
  // but that doesn't work because std::initializer_list is not movable. These
  // overloads let you do
  //
  //   CloneWithReplacementPairs({a, std::move(b)}, {c, std::move(d)});  // OK
  //
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
  std::unique_ptr<HloComputation> CloneWithReplacementPairs(
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
      std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
      HloCloneContext* context = nullptr, const string& suffix = "clone");
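
  // For example, cloning a computation while swapping one instruction for a
  // freshly created constant might look like this (an illustrative sketch;
  // `comp` and `instr` are hypothetical, and the zero literal is made up for
  // the example):
  //
  //   std::unique_ptr<HloComputation> clone = comp->CloneWithReplacementPairs(
  //       {instr, HloInstruction::CreateConstant(
  //                   LiteralUtil::CreateR0<float>(0.0f))});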
  // Returns true if the given instruction can be removed from the computation.
  // Parameter instructions cannot be removed without violating invariants of
  // the HLO computation, with the exception of fusion computations. A
  // parameter instruction is removable for a fusion computation.
  //
  // Note that IsRemovable() is a necessary condition to remove an instruction
  // rather than a sufficient condition. For example, instructions with
  // side-effect (e.g., Send, Infeed) may be removed from a computation, but
  // the transformation must guarantee the invariants relevant to the
  // instructions still hold (e.g., Send and Recv must be removed together to
  // make each channel complete).
  bool IsRemovable(const HloInstruction* instruction);

  // Returns a map from channel-id to the group of instructions associated with
  // the channel. These instructions will be considered as a single node for
  // dependency purposes. Send and RecvDone are in the group, and AllReduces
  // with the same channel id are in the group.
  using ChannelDependencyGroup =
      absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
  ChannelDependencyGroup ComputeChannelDependencies() const;

  // Returns true if this computation has a side effect. A computation has a
  // side effect if it contains one or more instructions with a side effect.
  bool HasSideEffect() const;

  // Returns whether this computation is a fusion computation.
  bool IsFusionComputation() const { return fusion_instruction_ != nullptr; }

  // Returns the owning fusion instruction, or nullptr if this is not a fusion
  // computation.
  HloInstruction* FusionInstruction() const { return fusion_instruction_; }
  void SetFusionInstruction(HloInstruction* fusion_instruction) {
    fusion_instruction_ = fusion_instruction;
  }

  // Clear the unique ID of the computation so that it can be re-assigned, such
  // as for the purpose of compacting the unique IDs.
  void ClearUniqueIdInternal() { unique_id_ = -1; }

  // The id of this computation should be unique within the module.
  void SetUniqueId(int64 id) {
    CHECK_EQ(unique_id_, -1);
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Returns the instruction in this computation that has name `name`. Returns
  // null if there is no such instruction.
  HloInstruction* GetInstructionWithName(absl::string_view name);

  int64 unique_id() const { return unique_id_; }

 private:
  explicit HloComputation(
      const string& name, int parameter_count,
      std::vector<std::unique_ptr<HloInstruction>>* instructions,
      HloInstruction* root_instruction, HloInstruction* fusion_instruction);

  // Internal helper for adding instructions.
  HloInstruction* AddInstructionInternal(
      std::unique_ptr<HloInstruction> instruction);

  // Fuses HLOs in instructions_to_fuse into fusion_instruction.
  //
  // Pre-condition: fusion_instruction's opcode is kFusion.
  void FuseInstructionsInto(
      absl::Span<HloInstruction* const> instructions_to_fuse,
      HloInstruction* fusion_instruction);

  // Internal helper for recursive copying of an instruction. Creates and
  // returns a deep copy of the given instruction.
  StatusOr<HloInstruction*> DeepCopyHelper(
      HloInstruction* instruction, ShapeIndex* index,
      const std::function<
          HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* computation)>& copy_leaf);

  // Internal helper to collect unreachable roots.
  std::vector<HloInstruction*> CollectUnreachableRoots() const;

  enum VisitState { kVisiting, kVisited };
  void ComputeInstructionPostOrder(
      const HloComputation::ChannelDependencyGroup& channel_dependency_map,
      std::vector<HloInstruction*>* post_order, HloInstruction* root,
      absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;

  string name_;
  int64 unique_id_;
  HloInstruction* root_instruction_;

  // If this computation is a fusion computation, this field points to the
  // corresponding fusion instruction. Otherwise, this is null.
  HloInstruction* fusion_instruction_;

  // Module containing this computation.
  HloModule* parent_ = nullptr;

  // Store instructions in std::list as they can be added and removed
  // arbitrarily and we want a stable iteration order. Keep a map from
  // instruction pointer to location in the list for fast lookup.
  using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
  InstructionList instructions_;
  absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
      instruction_iterators_;

  std::vector<HloInstruction*> param_instructions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_