• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
18 
#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 // Abstract base class for layout constraints. These constraint objects are
48 // gathered together in LayoutConstraints object.
49 class LayoutConstraint {
50  public:
LayoutConstraint(bool mandatory,bool dfs)51   LayoutConstraint(bool mandatory, bool dfs)
52       : mandatory_(mandatory), dfs_(dfs) {}
53   virtual ~LayoutConstraint() = default;
54 
55   virtual string ToString() const = 0;
56 
57   // True if this constraint cannot be overwritten by a different constraint.
mandatory()58   bool mandatory() const { return mandatory_; }
59 
60   // When true, propagate in DFS. When false, constraint will propagate in BFS.
dfs()61   bool dfs() const { return dfs_; }
62 
63  private:
64   bool mandatory_;
65   bool dfs_;
66 };
67 
68 std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);
69 
70 // Layout constraint on a single LogicalBuffer. This constrains the layout of an
71 // array produced by a particular instruction.
72 class BufferLayoutConstraint : public LayoutConstraint {
73  public:
74   BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer,
75                          bool mandatory, bool dfs);
76 
buffer()77   const LogicalBuffer& buffer() const { return *buffer_; }
layout()78   const Layout& layout() const { return layout_; }
79 
80   string ToString() const override;
81 
82  private:
83   Layout layout_;
84   const LogicalBuffer* buffer_;
85 };
86 
87 // Constraint on the layout of the operand of an instruction. The constrained
88 // shape can be arbitrarily shaped (array or tuple). This is a constraint on the
89 // use of a shaped value and is not a hard constraint on the instruction(s)
90 // which define the value as copies may be inserted between the definition and
91 // use.
92 class OperandLayoutConstraint : public LayoutConstraint {
93  public:
94   OperandLayoutConstraint(const ShapeLayout& shape_layout,
95                           const HloInstruction* instruction, int64 operand_no,
96                           bool mandatory, bool dfs);
97 
shape_layout()98   const ShapeLayout& shape_layout() const { return shape_layout_; }
instruction()99   const HloInstruction* instruction() const { return instruction_; }
operand_no()100   const int64 operand_no() const { return operand_no_; }
operand()101   const HloInstruction* operand() const {
102     return instruction_->operand(operand_no_);
103   }
104 
105   string ToString() const override;
106 
107  private:
108   ShapeLayout shape_layout_;
109   const HloInstruction* instruction_;
110   int64 operand_no_;
111 };
112 
113 // Constraint on the layout of the result of the entry computation.
114 class ResultLayoutConstraint : public LayoutConstraint {
115  public:
116   explicit ResultLayoutConstraint(const ShapeLayout& shape_layout,
117                                   bool dfs = false)
LayoutConstraint(true,dfs)118       : LayoutConstraint(/*mandatory=*/true, dfs),
119         shape_layout_(shape_layout) {}
120 
shape_layout()121   const ShapeLayout& shape_layout() const { return shape_layout_; }
122   string ToString() const override;
123 
124  private:
125   const ShapeLayout shape_layout_;
126 };
127 
128 // Class encapsulating the layout constraints of the values in a HLO
129 // computation.
130 class LayoutConstraints {
131  public:
132   LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis,
133                     HloComputation* computation);
134   ~LayoutConstraints() = default;
135 
computation()136   const HloComputation* computation() const { return computation_; }
computation()137   HloComputation* computation() { return computation_; }
points_to_analysis()138   const TuplePointsToAnalysis& points_to_analysis() const {
139     return points_to_analysis_;
140   }
141 
142   // Return a vector containing the constraints which have been added to the
143   // LayoutConstraints object since the construction of the object or since the
144   // last time ConsumeAddedConstraints() has been called. This is used to
145   // identify newly added constraints when propagating layouts.
ConsumeAddedConstraints()146   std::vector<const LayoutConstraint*> ConsumeAddedConstraints() {
147     std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_));
148     added_constraints_.clear();
149     return ret_vec;
150   }
ClearAddedConstraints()151   void ClearAddedConstraints() { added_constraints_.clear(); }
152 
153   // Returns the layout of a LogicalBuffer, the layout of the operand of the
154   // instruction, or the layout of the result of the computation, respectively,
155   // if it has been constrained. Otherwise return nullptr.
156   const Layout* BufferLayout(const LogicalBuffer& buffer) const;
157   const BufferLayoutConstraint* GetBufferLayoutConstraint(
158       const LogicalBuffer& buffer) const;
159   const ShapeLayout* OperandLayout(const HloInstruction* instruction,
160                                    int64 operand_no) const;
161   const OperandLayoutConstraint* GetOperandLayoutConstraint(
162       const HloInstruction* instruction, int64 operand_no) const;
163   const ShapeLayout* ResultLayout() const;
164 
165   // Add a constraint on the layout of a LogicalBuffer, the layout of the
166   // operand of the instruction, or the layout of the result of the computation,
167   // respectively.
168   Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer,
169                          bool mandatory = true, bool dfs = true);
170   Status SetOperandLayout(const Shape& shape_with_layout,
171                           const HloInstruction* instruction, int64 operand_no,
172                           bool mandatory = true, bool dfs = true);
173   Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true);
174 
175   // Convenience wrapper around SetOperandLayout for setting the layout of a
176   // operand using a Layout object. The operand must be array-shaped.
177   Status SetArrayOperandLayout(const Layout& layout,
178                                const HloInstruction* instruction,
179                                int64 operand_no, bool mandatory = true,
180                                bool dfs = true);
181 
182   // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers
183   // created by the instruction to the layouts in the given shape. The
184   // instruction must define every logical buffer in its output.
185   Status SetInstructionLayout(const Shape& shape_with_layout,
186                               const HloInstruction* instruction,
187                               bool mandatory = true, bool dfs = true);
188 
189   // Returns true if any buffer in the given operand is forwarded to the output
190   // of the given instruction. For example, the Tuple instruction forwards the
191   // buffers of its operands and would return true for each of its operands.
192   bool OperandBufferForwarded(const HloInstruction* instruction,
193                               int64 operand_no) const;
194 
195   // Returns the set of logical buffers (by LogicalBuffer:Id) which do not
196   // yet have a layout constraint
unconstrained_buffer_ids()197   const std::set<LogicalBuffer::Id>& unconstrained_buffer_ids() const {
198     return unconstrained_buffer_ids_;
199   }
200 
201   string ToString() const;
202 
203  private:
204   // Find a bufferset in the bufferset cache. This is useful since we can
205   // currently create the flattened buffer set for the same instruction many
206   // times, which is often slow.
207   PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const;
208 
209   // The set of BufferLayoutConstraints applied to the computation.
210   std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint>
211       buffer_constraints_;
212 
213   // The set of OperandLayoutConstraints applied to the computation.
214   using OperandConstraintKey = std::pair<const HloInstruction*, int64>;
215   std::map<OperandConstraintKey, OperandLayoutConstraint> operand_constraints_;
216 
217   // The result constraint for the computation (can be null).
218   std::unique_ptr<ResultLayoutConstraint> result_constraint_;
219 
220   // A vector which holds constraints as they are added. Can be cleared with
221   // ClearAddedConstraints.
222   std::vector<const LayoutConstraint*> added_constraints_;
223 
224   // Points-to analysis for the module. Used to propagate constraints through
225   // the HLO graph.
226   const TuplePointsToAnalysis& points_to_analysis_;
227 
228   // Array-shaped buffers which have not yet been constrained.
229   std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;
230 
231   mutable absl::flat_hash_map<const HloInstruction*,
232                               std::unique_ptr<PointsToSet::BufferSet>>
233       buffer_sets_cache_;
234 
235   HloComputation* computation_;
236 };
237 
238 // Contains constraints on the layout of channels; sends and recvs.
239 class ChannelLayoutConstraints {
240  public:
241   // Construct an empty constraint set.
ChannelLayoutConstraints()242   ChannelLayoutConstraints() {}
243 
244   // Returns true if channel_id has a layout constraint.
IsChannelConstrained(int64 channel_id)245   bool IsChannelConstrained(int64 channel_id) const {
246     return constraints_.contains(channel_id);
247   }
248 
249   // Given `shape`, apply the layout for `channel_id`. `channel_id` must already
250   // be constrained.
LayoutShapeForChannel(Shape shape,int64 channel_id)251   Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const {
252     auto it = constraints_.find(channel_id);
253     CHECK(it != constraints_.end()) << "Channel " << channel_id;
254     *shape.mutable_layout() = it->second;
255     return shape;
256   }
257 
258   // Returns the layout constraint for `channel_id`, which must already be
259   // constrained.
LayoutForChannel(int64 channel_id)260   const Layout& LayoutForChannel(int64 channel_id) const {
261     auto it = constraints_.find(channel_id);
262     CHECK(it != constraints_.end()) << "Channel " << channel_id;
263     return it->second;
264   }
265 
266   // Adds a new layout constraint for `channel_id`. If a constraint for
267   // `channel_id` has been added, this API returns nullptr, otherwise returns
268   // the layout which has already been set for the channel.
ConstrainChannel(int64 channel_id,const Layout & layout)269   const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) {
270     auto it = constraints_.emplace(std::make_pair(channel_id, layout));
271     if (it.second) {
272       return nullptr;
273     }
274     return LayoutUtil::Equal(layout, it.first->second) ? nullptr
275                                                        : &it.first->second;
276   }
277 
278  private:
279   absl::flat_hash_map<int64, Layout> constraints_;
280 };
281 
282 // HLO pass which assigns layouts to all instructions in the HLO module while
283 // satisfying all necessary invariants and minimizing cost.
284 class LayoutAssignment : public HloModulePass {
285  public:
286   // entry_computation_layout is modified to populate a layout for the result in
287   // the case that no particular layout is requested.
288   //
289   // instruction_can_change_layout_func is a function object that determines
290   // whether an instruction can change layouts. An instruction not being able to
291   // change layout means that it requires operands with the same rank as the
292   // output to have the same layout as the output.
293   //
294   // channel_constraints is both an input and output. Any sends or recvs that
295   // are present in channel_constraints will be laid out as constrained. Any
296   // unconstrained sends or recvs will be laid out as locally optimal and their
297   // layout will be added as a constraint to channel_constraints.
298   //
299   // If channel_constraints is nullptr, no kSend or kRecvs must be contained
300   // within any module passed to `Run`.
301   explicit LayoutAssignment(
302       ComputationLayout* entry_computation_layout,
303       std::function<bool(const HloInstruction*)>
304           instruction_can_change_layout_func = InstructionCanChangeLayout,
305       ChannelLayoutConstraints* channel_constraints = nullptr);
~LayoutAssignment()306   ~LayoutAssignment() override {}
name()307   absl::string_view name() const override { return "layout-assignment"; }
308 
309   // Assign layouts to the given module. Returns whether the module was changed
310   // (any layouts were changed).
311   StatusOr<bool> Run(HloModule* module) override;
312 
313   // Determines whether an instruction can change layouts. An instruction not
314   // being able to change layout means that it requires operands with the same
315   // rank as the output to have the same layout as the output.
316   static bool InstructionCanChangeLayout(const HloInstruction* instruction);
317 
318   // In case of an array shape returns true iff it is at most rank 1. In case of
319   // a tuple shape returns true iff all leaf shapes are at most rank 1.
320   static bool IsAtMostRank1(const Shape& shape);
321 
322  protected:
323   // These methods, invoked by PropagateConstraints, propagate a layout
324   // constraint to its neighbors (i.e. operands and users) in order to minimize
325   // the cost of the instructions being constrainted on. New constraints are
326   // added to the given constraint set.
327   //
328   // Backends can override these methods with backend-specific propagation
329   // rules.
330   virtual Status PropagateBufferConstraint(
331       const BufferLayoutConstraint& layout_constraint,
332       LayoutConstraints* constraints);
333   virtual Status PropagateOperandConstraint(
334       const OperandLayoutConstraint& layout_constraint,
335       LayoutConstraints* constraints);
336   virtual Status PropagateResultConstraint(
337       const ResultLayoutConstraint& layout_constraint,
338       LayoutConstraints* constraints);
339 
340   // Called after layouts of an instruction have been finalized to allow
341   // subclasses to check for platform specific assumptions.
Verify(const HloInstruction * instruction)342   virtual Status Verify(const HloInstruction* instruction) {
343     return Status::OK();
344   }
345 
346   // Propagates a buffer layout constraint into the operands that use it.
347   Status PropagateBufferConstraintToUses(
348       const BufferLayoutConstraint& layout_constraint,
349       LayoutConstraints* constraints);
350 
351   // Propagates a layout constraint on the use of the result of the given
352   // instruction to the definitions of the LogicalBuffers which make up the
353   // result.
354   Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout,
355                                       const HloInstruction* instruction,
356                                       LayoutConstraints* constraints);
357 
358   // Chooses a layout of operand `operand_no` of `instruction` that minimizes
359   // the cost of `instruction`. `output_layout` is the layout of `instruction`.
360   // Returns null if it can't decide the best layout.
361   // Precondition: `instruction` and the operand are array-shaped.
362   std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout(
363       const Layout& output_layout, const HloInstruction* instruction,
364       int64 operand_no);
365   // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of
366   // `user` that minimizes its cost on that operand.  Returns null if it can't
367   // decide the best layout.
368   // Precondition: `user` and the operand are array-shaped.
369   virtual std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout(
370       const Layout& operand_layout, const HloInstruction* user,
371       int64 operand_no);
372 
373  private:
374   // Initializes the layout assignment object for a new Run() call.
375   Status Init();
376 
377   // Adds constraints which must be satisfied for correctness on all
378   // backends. Called once prior to propagating constraints.
379   Status AddMandatoryConstraints(const ComputationLayout* computation_layout,
380                                  ChannelLayoutConstraints* channel_constraints,
381                                  HloComputation* computation,
382                                  LayoutConstraints* constraints);
383 
384   // This method can be overridden to add backend-specific constraints to the
385   // layout of the instructions of a computation. This method is called after
386   // all mandatory constraints have been added via AddMandatoryConstraints
387   // and before propagating constraints.
AddBackendConstraints(LayoutConstraints * constraints)388   virtual Status AddBackendConstraints(LayoutConstraints* constraints) {
389     return Status::OK();
390   }
391 
392   // Construct contraints and assign layouts to all instructions in the
393   // computation satisfying the given ComputationLayout, if not nullptr.
394   // Otherwise the ComputationLayout will be calculated by propagating the
395   // computation instruction contraints.
396   // Layouts constraints are added, then propagated until all LogicalBuffers in
397   // the computation are constrained.
398   Status RunOnComputation(ComputationLayout* computation_layout,
399                           const TuplePointsToAnalysis& points_to_analysis,
400                           HloComputation* computation,
401                           ChannelLayoutConstraints* channel_constraints);
402 
403   // Assign layouts to the instructions of a computation which satisfy the given
404   // layout constraints. Copies may be added to satisfy the constraints. The
405   // given LayoutConstraints must have layout constraints every logical buffer
406   // in the computation.
407   Status AssignLayouts(const LayoutConstraints& constraints,
408                        HloComputation* computation);
409 
410   // Propagates layout constraints from a set of initial constraints in order to
411   // minimize the local cost of the computation. This propagation is *not*
412   // required for correctness.
413   Status PropagateConstraints(LayoutConstraints* constraints);
414 
415   Status PropagateBufferConstraintToOperands(
416       const BufferLayoutConstraint& buffer_constraint,
417       LayoutConstraints* constraints);
418 
419   // Check that all layouts in the module have been set and satisfy all
420   // necessary conditions.
421   Status CheckLayouts(HloModule* module);
422 
423   // Computes the ComputationLayout of the given computation based of the
424   // layouts assigned to parameters and root instruction, and inserts it to the
425   // computation_layouts_ map.
426   Status CalculateComputationLayout(HloComputation* computation);
427 
428   // Clears all the layouts which can be cleared within a computation.
429   Status ClearComputationLayouts(HloComputation* computation);
430 
431   // Clears the side effects of a previous pass, like added copy instructions.
432   Status ClearPreviousPassSideEffects(HloModule* module);
433 
434   // Propagates the layouts computed by the layout assignment pass on the given
435   // computation, to the computation layout passed in to this API.
436   // This API propagates missing layout, and also checks that the caller
437   // specified have been respected, by comparing those with the parameters and
438   // root computation instruction.
439   Status PropagateComputationLayouts(HloComputation* computation,
440                                      ComputationLayout* computation_layout);
441 
442   // The pointer to the ComputationLayout passed as constructor parameter.
443   ComputationLayout* entry_computation_layout_;
444 
445   // A copy of entry_computation_layout_ used to reset it to the initial values
446   // during the multiple passes done by the layout assignment operation.
447   ComputationLayout saved_entry_computation_layout_;
448 
449  protected:
450   // Sets up the copy instruction according to the characteristic (sharding,
451   // metadata, ...) of the reference instruction. The index argument is used
452   // when the instruction is a tuple, and in such case the index represents
453   // the location from where the copy instruction was created from.
454   // If the index is empty, the whole sharding will be propagated, even in case
455   // the intruction has a tuple sharding.
456   static void SetupCopiedInstruction(const HloInstruction& instruction,
457                                      HloInstruction* copy,
458                                      const ShapeIndex& index);
459 
460   // Creates and returns a copy of the given instruction with a different
461   // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
462   // instruction producing the copy is returned.
463   StatusOr<HloInstruction*> CreateCopyWithNewLayout(
464       const Shape& shape_with_layout, HloInstruction* instruction);
465 
466   // Creates a copy of the given operand if the operand's layout does not match
467   // the given layout. This copy replaces the use in the given instruction.
468   // Tuple operands will be deep-copied.
469   Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
470                                     HloInstruction* instruction,
471                                     int64 operand_no);
472 
473   // Registers a copy instruction added by the layout assignment pass.
RegisterAddedCopy(HloInstruction * copy)474   void RegisterAddedCopy(HloInstruction* copy) {
475     CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
476     added_copies_.insert(copy);
477   }
478 
479   // Adds a copy for the operand of an instruction, unless such operand is
480   // already a copy, and has a single user (which is forcibly the instruction
481   // itself).
482   Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);
483 
484   // Apply the channel layout constraints by populating the channel_constraints
485   // data structure passed in at constructor time. Eventually adds copies in
486   // case two ends of a channel ended up with a different leyout.
487   Status ConstrainChannelLayouts(HloComputation* computation,
488                                  ChannelLayoutConstraints* channel_constraints);
489 
490   // Resets the input ChannelLayoutConstraints to the original copy received
491   // from the constructor input.
ResetChannelConstraints()492   void ResetChannelConstraints() {
493     if (channel_layout_constraints_ != nullptr) {
494       *channel_layout_constraints_ = channel_constraints_;
495     }
496   }
497 
498   // Adds constraints related to host Send/Recv instructions.
499   Status BuildHostChannelConstraints(HloComputation* computation);
500 
501   // Map containing the layouts of all computations assigned so
502   // far. Computations are handled in a topological sort where computations are
503   // handled before their caller instructions so the layouts of caller
504   // instructions can be set to match the computation.
505   std::map<HloComputation*, ComputationLayout> computation_layouts_;
506 
507   // Every copy added to the module by the layout assignment pass is registered
508   // here.
509   absl::flat_hash_set<HloInstruction*> added_copies_;
510 
511   // The pointer to the channel layout constraints passed in with the
512   // constructor. If not nullptr, this is an input/output argument.
513   ChannelLayoutConstraints* channel_layout_constraints_ = nullptr;
514 
515   // A copy of the input layout constraints used to reset the above pointer in
516   // case we have to undo operations due to the multiple passes over the
517   // computations/instructions.
518   ChannelLayoutConstraints channel_constraints_;
519 
520   // Layout constraints for send/recv instructions which communicate with the
521   // host.
522   ChannelLayoutConstraints host_channel_constraints_;
523 
524   // The set of HLO instructions which lacked any layout constraint, thus
525   // receiving propagated default layouts.
526   absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;
527 
528   std::function<bool(const HloInstruction*)>
529       instruction_can_change_layout_func_;
530 };
531 
532 }  // namespace xla
533 
534 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
535