1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 18 19 #include <iosfwd> 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <string> 24 #include <unordered_map> 25 #include <utility> 26 #include <vector> 27 28 #include "absl/container/flat_hash_map.h" 29 #include "absl/container/flat_hash_set.h" 30 #include "tensorflow/compiler/xla/layout_util.h" 31 #include "tensorflow/compiler/xla/service/call_graph.h" 32 #include "tensorflow/compiler/xla/service/computation_layout.h" 33 #include "tensorflow/compiler/xla/service/hlo_computation.h" 34 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 35 #include "tensorflow/compiler/xla/service/hlo_module.h" 36 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 37 #include "tensorflow/compiler/xla/service/logical_buffer.h" 38 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 39 #include "tensorflow/compiler/xla/shape_layout.h" 40 #include "tensorflow/compiler/xla/shape_util.h" 41 #include "tensorflow/compiler/xla/statusor.h" 42 #include "tensorflow/compiler/xla/types.h" 43 #include "tensorflow/compiler/xla/xla_data.pb.h" 44 #include "tensorflow/core/lib/core/status.h" 45 #include 
"tensorflow/core/platform/types.h"

namespace xla {

// Abstract base class for layout constraints. These constraint objects are
// gathered together in LayoutConstraints object.
class LayoutConstraint {
 public:
  LayoutConstraint(bool mandatory, bool dfs)
      : mandatory_(mandatory), dfs_(dfs) {}
  virtual ~LayoutConstraint() = default;

  virtual string ToString() const = 0;

  // True if this constraint cannot be overwritten by a different constraint.
  bool mandatory() const { return mandatory_; }

  // When true, propagate in DFS. When false, constraint will propagate in BFS.
  bool dfs() const { return dfs_; }

 private:
  bool mandatory_;
  bool dfs_;
};

std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);

// Layout constraint on a single LogicalBuffer. This constrains the layout of an
// array produced by a particular instruction.
class BufferLayoutConstraint : public LayoutConstraint {
 public:
  BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
                         bool mandatory, bool dfs);

  const LogicalBuffer& buffer() const { return *buffer_; }
  const Layout& layout() const { return layout_; }

  string ToString() const override;

 private:
  Layout layout_;
  const LogicalBuffer* buffer_;
};

// Constraint on the layout of the operand of an instruction. The constrained
// shape can be arbitrarily shaped (array or tuple). This is a constraint on the
// use of a shaped value and is not a hard constraint on the instruction(s)
// which define the value as copies may be inserted between the definition and
// use.
class OperandLayoutConstraint : public LayoutConstraint {
 public:
  OperandLayoutConstraint(const ShapeLayout& shape_layout,
                          const HloInstruction* instruction, int64 operand_no,
                          bool mandatory, bool dfs);

  const ShapeLayout& shape_layout() const { return shape_layout_; }
  const HloInstruction* instruction() const { return instruction_; }
  // Returns by value; a top-level const on the return type would be
  // meaningless (clang-tidy: readability-const-return-type).
  int64 operand_no() const { return operand_no_; }
  const HloInstruction* operand() const {
    return instruction_->operand(operand_no_);
  }

  string ToString() const override;

 private:
  ShapeLayout shape_layout_;
  const HloInstruction* instruction_;
  int64 operand_no_;
};

// Constraint on the layout of the result of the entry computation.
class ResultLayoutConstraint : public LayoutConstraint {
 public:
  explicit ResultLayoutConstraint(const ShapeLayout& shape_layout,
                                  bool dfs = false)
      : LayoutConstraint(/*mandatory=*/true, dfs),
        shape_layout_(shape_layout) {}

  const ShapeLayout& shape_layout() const { return shape_layout_; }
  string ToString() const override;

 private:
  const ShapeLayout shape_layout_;
};

// Class encapsulating the layout constraints of the values in a HLO
// computation.
132 class LayoutConstraints { 133 public: 134 LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis, 135 HloComputation* computation); 136 ~LayoutConstraints() = default; 137 computation()138 const HloComputation* computation() const { return computation_; } computation()139 HloComputation* computation() { return computation_; } points_to_analysis()140 const TuplePointsToAnalysis& points_to_analysis() const { 141 return points_to_analysis_; 142 } 143 144 // Return a vector containing the constraints which have been added to the 145 // LayoutConstraints object since the construction of the object or since the 146 // last time ConsumeAddedConstraints() has been called. This is used to 147 // identify newly added constraints when propagating layouts. ConsumeAddedConstraints()148 std::vector<const LayoutConstraint*> ConsumeAddedConstraints() { 149 std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_)); 150 added_constraints_.clear(); 151 return ret_vec; 152 } ClearAddedConstraints()153 void ClearAddedConstraints() { added_constraints_.clear(); } 154 155 // Returns the layout of a LogicalBuffer, the layout of the operand of the 156 // instruction, or the layout of the result of the computation, respectively, 157 // if it has been constrained. Otherwise return nullptr. 158 const Layout* BufferLayout(const LogicalBuffer& buffer) const; 159 const BufferLayoutConstraint* GetBufferLayoutConstraint( 160 const LogicalBuffer& buffer) const; 161 const ShapeLayout* OperandLayout(const HloInstruction* instruction, 162 int64 operand_no) const; 163 const OperandLayoutConstraint* GetOperandLayoutConstraint( 164 const HloInstruction* instruction, int64 operand_no) const; 165 const ShapeLayout* ResultLayout() const; 166 167 // Add a constraint on the layout of a LogicalBuffer, the layout of the 168 // operand of the instruction, or the layout of the result of the computation, 169 // respectively. 
170 Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, 171 bool mandatory = true, bool dfs = true); 172 Status SetOperandLayout(const Shape& shape_with_layout, 173 const HloInstruction* instruction, int64 operand_no, 174 bool mandatory = true, bool dfs = true); 175 Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true); 176 177 // Convenience wrapper around SetOperandLayout for setting the layout of a 178 // operand using a Layout object. The operand must be array-shaped. 179 Status SetArrayOperandLayout(const Layout& layout, 180 const HloInstruction* instruction, 181 int64 operand_no, bool mandatory = true, 182 bool dfs = true); 183 184 // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers 185 // created by the instruction to the layouts in the given shape. The 186 // instruction must define every logical buffer in its output. 187 Status SetInstructionLayout(const Shape& shape_with_layout, 188 const HloInstruction* instruction, 189 bool mandatory = true, bool dfs = true); 190 191 // Returns true if any buffer in the given operand is forwarded to the output 192 // of the given instruction. For example, the Tuple instruction forwards the 193 // buffers of its operands and would return true for each of its operands. 194 bool OperandBufferForwarded(const HloInstruction* instruction, 195 int64 operand_no) const; 196 197 // Returns the set of logical buffers (by LogicalBuffer:Id) which do not 198 // yet have a layout constraint unconstrained_buffer_ids()199 const std::set<LogicalBuffer::Id>& unconstrained_buffer_ids() const { 200 return unconstrained_buffer_ids_; 201 } 202 203 string ToString() const; 204 205 private: 206 // Find a bufferset in the bufferset cache. This is useful since we can 207 // currently create the flattened buffer set for the same instruction many 208 // times, which is often slow. 
209 PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const; 210 211 // The set of BufferLayoutConstraints applied to the computation. 212 std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint> 213 buffer_constraints_; 214 215 // The set of OperandLayoutConstraints applied to the computation. 216 using OperandConstraintKey = std::pair<const HloInstruction*, int64>; 217 std::map<OperandConstraintKey, OperandLayoutConstraint> operand_constraints_; 218 219 // The result constraint for the computation (can be null). 220 std::unique_ptr<ResultLayoutConstraint> result_constraint_; 221 222 // A vector which holds constraints as they are added. Can be cleared with 223 // ClearAddedConstraints. 224 std::vector<const LayoutConstraint*> added_constraints_; 225 226 // Points-to analysis for the module. Used to propagate constraints through 227 // the HLO graph. 228 const TuplePointsToAnalysis& points_to_analysis_; 229 230 // Array-shaped buffers which have not yet been constrained. 231 std::set<LogicalBuffer::Id> unconstrained_buffer_ids_; 232 233 mutable absl::flat_hash_map<const HloInstruction*, 234 std::unique_ptr<PointsToSet::BufferSet>> 235 buffer_sets_cache_; 236 237 HloComputation* computation_; 238 }; 239 240 // Contains constraints on the layout of channels; sends and recvs. 241 class ChannelLayoutConstraints { 242 public: 243 // Construct an empty constraint set. ChannelLayoutConstraints()244 ChannelLayoutConstraints() {} 245 246 // Returns true if channel_id has a layout constraint. IsChannelConstrained(int64 channel_id)247 bool IsChannelConstrained(int64 channel_id) const { 248 return constraints_.contains(channel_id); 249 } 250 251 // Given `shape`, apply the layout for `channel_id`. `channel_id` must already 252 // be constrained. 
LayoutShapeForChannel(Shape shape,int64 channel_id)253 Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const { 254 auto it = constraints_.find(channel_id); 255 CHECK(it != constraints_.end()) << "Channel " << channel_id; 256 *shape.mutable_layout() = it->second; 257 return shape; 258 } 259 260 // Returns the layout constraint for `channel_id`, which must already be 261 // constrained. LayoutForChannel(int64 channel_id)262 const Layout& LayoutForChannel(int64 channel_id) const { 263 auto it = constraints_.find(channel_id); 264 CHECK(it != constraints_.end()) << "Channel " << channel_id; 265 return it->second; 266 } 267 268 // Adds a new layout constraint for `channel_id`. If a constraint for 269 // `channel_id` has been added, this API returns nullptr, otherwise returns 270 // the layout which has already been set for the channel. ConstrainChannel(int64 channel_id,const Layout & layout)271 const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) { 272 auto it = constraints_.emplace(std::make_pair(channel_id, layout)); 273 if (it.second) { 274 return nullptr; 275 } 276 return LayoutUtil::Equal(layout, it.first->second) ? nullptr 277 : &it.first->second; 278 } 279 280 private: 281 absl::flat_hash_map<int64, Layout> constraints_; 282 }; 283 284 // HLO pass which assigns layouts to all instructions in the HLO module while 285 // satisfying all necessary invariants and minimizing cost. 286 class LayoutAssignment : public HloModulePass { 287 public: 288 // entry_computation_layout is modified to populate a layout for the result in 289 // the case that no particular layout is requested. 290 // 291 // instruction_can_change_layout_func is a function object that determines 292 // whether an instruction can change layouts. An instruction not being able to 293 // change layout means that it requires operands with the same rank as the 294 // output to have the same layout as the output. 295 // 296 // channel_constraints is both an input and output. 
Any sends or recvs that 297 // are present in channel_constraints will be laid out as constrained. Any 298 // unconstrained sends or recvs will be laid out as locally optimal and their 299 // layout will be added as a constraint to channel_constraints. 300 // 301 // If channel_constraints is nullptr, no kSend or kRecvs must be contained 302 // within any module passed to `Run`. 303 explicit LayoutAssignment( 304 ComputationLayout* entry_computation_layout, 305 std::function<bool(const HloInstruction*)> 306 instruction_can_change_layout_func = InstructionCanChangeLayout, 307 ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment()308 ~LayoutAssignment() override {} name()309 absl::string_view name() const override { return "layout-assignment"; } 310 311 // Assign layouts to the given module. Returns whether the module was changed 312 // (any layouts were changed). 313 StatusOr<bool> Run(HloModule* module) override; 314 315 // Determines whether an instruction can change layouts. An instruction not 316 // being able to change layout means that it requires operands with the same 317 // rank as the output to have the same layout as the output. 318 static bool InstructionCanChangeLayout(const HloInstruction* instruction); 319 320 // In case of an array shape returns true iff it is at most rank 1. In case of 321 // a tuple shape returns true iff all leaf shapes are at most rank 1. 322 static bool IsAtMostRank1(const Shape& shape); 323 324 protected: 325 // These methods, invoked by PropagateConstraints, propagate a layout 326 // constraint to its neighbors (i.e. operands and users) in order to minimize 327 // the cost of the instructions being constrainted on. New constraints are 328 // added to the given constraint set. 329 // 330 // Backends can override these methods with backend-specific propagation 331 // rules. 
332 virtual Status PropagateBufferConstraint( 333 const BufferLayoutConstraint& layout_constraint, 334 LayoutConstraints* constraints); 335 virtual Status PropagateOperandConstraint( 336 const OperandLayoutConstraint& layout_constraint, 337 LayoutConstraints* constraints); 338 virtual Status PropagateResultConstraint( 339 const ResultLayoutConstraint& layout_constraint, 340 LayoutConstraints* constraints); 341 GetUnconstrainedLayout(const LogicalBuffer & buffer)342 virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) { 343 return LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); 344 } 345 // Called after layouts of an instruction have been finalized to allow 346 // subclasses to check for platform specific assumptions. Verify(const HloInstruction * instruction)347 virtual Status Verify(const HloInstruction* instruction) { 348 return Status::OK(); 349 } 350 351 // Propagates a buffer layout constraint into the operands that use it. 352 Status PropagateBufferConstraintToUses( 353 const BufferLayoutConstraint& layout_constraint, 354 LayoutConstraints* constraints); 355 356 // Propagates a layout constraint on the use of the result of the given 357 // instruction to the definitions of the LogicalBuffers which make up the 358 // result. 359 Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout, 360 const HloInstruction* instruction, 361 LayoutConstraints* constraints); 362 363 // Propagates the memory space defined in the entry computation to the called 364 // computations. 365 Status PropagateMemorySpace(HloModule* module); 366 367 // Chooses a layout of operand `operand_no` of `instruction` that minimizes 368 // the cost of `instruction`. `output_layout` is the layout of `instruction`. 369 // Returns null if it can't decide the best layout. 370 // Precondition: `instruction` and the operand are array-shaped. 
371 virtual std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout( 372 const Layout& output_layout, const HloInstruction* instruction, 373 int64 operand_no); 374 // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of 375 // `user` that minimizes its cost on that operand. Returns null if it can't 376 // decide the best layout. 377 // Precondition: `user` and the operand are array-shaped. 378 virtual std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout( 379 const Layout& operand_layout, const HloInstruction* user, 380 int64 operand_no); 381 382 private: 383 // Initializes the layout assignment object for a new Run() call. 384 Status Init(); 385 386 // Adds constraints which must be satisfied for correctness on all 387 // backends. Called once prior to propagating constraints. 388 Status AddMandatoryConstraints(const ComputationLayout* computation_layout, 389 ChannelLayoutConstraints* channel_constraints, 390 HloComputation* computation, 391 LayoutConstraints* constraints); 392 393 // This method can be overridden to add backend-specific constraints to the 394 // layout of the instructions of a computation. This method is called after 395 // all mandatory constraints have been added via AddMandatoryConstraints 396 // and before propagating constraints. AddBackendConstraints(LayoutConstraints * constraints)397 virtual Status AddBackendConstraints(LayoutConstraints* constraints) { 398 return Status::OK(); 399 } 400 401 // Construct constraints and assign layouts to all instructions in the 402 // computation satisfying the given ComputationLayout, if not nullptr. 403 // Otherwise the ComputationLayout will be calculated by propagating the 404 // computation instruction constraints. 405 // Layouts constraints are added, then propagated until all LogicalBuffers in 406 // the computation are constrained. 
407 Status RunOnComputation(ComputationLayout* computation_layout, 408 HloComputation* computation, 409 ChannelLayoutConstraints* channel_constraints); 410 411 // Assign layouts to the instructions of a computation which satisfy the given 412 // layout constraints. Copies may be added to satisfy the constraints. The 413 // given LayoutConstraints must have layout constraints every logical buffer 414 // in the computation. 415 Status AssignLayouts(const LayoutConstraints& constraints, 416 HloComputation* computation); 417 418 // Propagates layout constraints from a set of initial constraints in order to 419 // minimize the local cost of the computation. This propagation is *not* 420 // required for correctness. 421 Status PropagateConstraints(LayoutConstraints* constraints); 422 423 Status PropagateBufferConstraintToOperands( 424 const BufferLayoutConstraint& buffer_constraint, 425 LayoutConstraints* constraints); 426 427 // Check that all layouts in the module have been set and satisfy all 428 // necessary conditions. 429 Status CheckLayouts(HloModule* module); 430 431 // Computes the ComputationLayout of the given computation based of the 432 // layouts assigned to parameters and root instruction, and inserts it to the 433 // computation_layouts_ map. 434 Status CalculateComputationLayout(HloComputation* computation); 435 436 // Clears all the layouts which can be cleared within a computation. 437 Status ClearComputationLayouts(HloComputation* computation); 438 439 // Clears the side effects of a previous pass, like added copy instructions. 440 Status ClearPreviousPassSideEffects(HloModule* module); 441 442 // Propagates the layouts computed by the layout assignment pass on the given 443 // computation, to the computation layout passed in to this API. 444 // This API propagates missing layout, and also checks that the caller 445 // specified have been respected, by comparing those with the parameters and 446 // root computation instruction. 
447 Status PropagateComputationLayouts(HloComputation* computation, 448 ComputationLayout* computation_layout); 449 450 // The pointer to the ComputationLayout passed as constructor parameter. 451 ComputationLayout* entry_computation_layout_; 452 453 // A copy of entry_computation_layout_ used to reset it to the initial values 454 // during the multiple passes done by the layout assignment operation. 455 ComputationLayout saved_entry_computation_layout_; 456 457 protected: 458 // Sets up the copy instruction according to the characteristic (sharding, 459 // metadata, ...) of the reference instruction. The index argument is used 460 // when the instruction is a tuple, and in such case the index represents 461 // the location from where the copy instruction was created from. 462 // If the index is empty, the whole sharding will be propagated, even in case 463 // the instruction has a tuple sharding. 464 static void SetupCopiedInstruction(const HloInstruction& instruction, 465 HloInstruction* copy, 466 const ShapeIndex& index); 467 468 // Creates and returns a copy of the given instruction with a different 469 // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple 470 // instruction producing the copy is returned. 471 StatusOr<HloInstruction*> CreateCopyWithNewLayout( 472 const Shape& shape_with_layout, HloInstruction* instruction); 473 474 // Creates a copy of the given operand if the operand's layout does not match 475 // the given layout. This copy replaces the use in the given instruction. 476 // Tuple operands will be deep-copied. 477 virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, 478 HloInstruction* instruction, 479 int64 operand_no); 480 481 // Registers a copy instruction added by the layout assignment pass. 
RegisterAddedCopy(HloInstruction * copy)482 void RegisterAddedCopy(HloInstruction* copy) { 483 CHECK_EQ(copy->opcode(), HloOpcode::kCopy); 484 added_copies_.insert(copy); 485 } 486 487 // Adds a copy for the operand of an instruction, unless such operand is 488 // already a copy, and has a single user (which is forcibly the instruction 489 // itself). 490 Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); 491 492 // Apply the channel layout constraints by populating the channel_constraints 493 // data structure passed in at constructor time. Eventually adds copies in 494 // case two ends of a channel ended up with a different leyout. 495 Status ConstrainChannelLayouts(HloComputation* computation, 496 ChannelLayoutConstraints* channel_constraints); 497 498 // Resets the input ChannelLayoutConstraints to the original copy received 499 // from the constructor input. ResetChannelConstraints()500 void ResetChannelConstraints() { 501 if (channel_layout_constraints_ != nullptr) { 502 *channel_layout_constraints_ = channel_constraints_; 503 } 504 } 505 506 // Adds constraints related to host Send/Recv instructions. 507 Status BuildHostChannelConstraints(HloComputation* computation); 508 509 // Map containing the layouts of all computations assigned so 510 // far. Computations are handled in a topological sort where computations are 511 // handled before their caller instructions so the layouts of caller 512 // instructions can be set to match the computation. 513 std::map<HloComputation*, ComputationLayout> computation_layouts_; 514 515 // Map from branch computations to the result layout they should apply. 516 std::map<HloComputation*, ComputationLayout> conditional_mismatch_; 517 518 // Every copy added to the module by the layout assignment pass is registered 519 // here. 520 absl::flat_hash_set<HloInstruction*> added_copies_; 521 522 // The pointer to the channel layout constraints passed in with the 523 // constructor. 
If not nullptr, this is an input/output argument. 524 ChannelLayoutConstraints* channel_layout_constraints_ = nullptr; 525 526 // A copy of the input layout constraints used to reset the above pointer in 527 // case we have to undo operations due to the multiple passes over the 528 // computations/instructions. 529 ChannelLayoutConstraints channel_constraints_; 530 531 // Layout constraints for send/recv instructions which communicate with the 532 // host. 533 ChannelLayoutConstraints host_channel_constraints_; 534 535 // Module points to analysis that can be updated for cloned computations. 536 std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; 537 538 // The set of HLO instructions which lacked any layout constraint, thus 539 // receiving propagated default layouts. 540 absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_; 541 542 std::function<bool(const HloInstruction*)> 543 instruction_can_change_layout_func_; 544 545 // CallGraph of the module, used to track callsites of each computation. 546 std::unique_ptr<CallGraph> call_graph_; 547 }; 548 549 } // namespace xla 550 551 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_ 552