• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28 
29 #include "absl/algorithm/container.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/map_util.h"
37 #include "tensorflow/compiler/xla/permutation_util.h"
38 #include "tensorflow/compiler/xla/service/call_graph.h"
39 #include "tensorflow/compiler/xla/service/computation_layout.h"
40 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_dce.h"
44 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/logical_buffer.h"
48 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
49 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
50 #include "tensorflow/compiler/xla/shape_layout.h"
51 #include "tensorflow/compiler/xla/shape_util.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/compiler/xla/statusor.h"
54 #include "tensorflow/compiler/xla/types.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/logging.h"
60 #include "tensorflow/core/platform/protobuf.h"
61 
62 namespace xla {
63 
operator <<(std::ostream & out,const LayoutConstraint & constraint)64 std::ostream& operator<<(std::ostream& out,
65                          const LayoutConstraint& constraint) {
66   out << constraint.ToString();
67   return out;
68 }
69 
// Constructs a constraint pinning `layout` onto `buffer`. The layout must be
// valid for the buffer's shape (CHECK-verified below).
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs,
                                               int64_t priority)
    : LayoutConstraint(mandatory, dfs, priority),
      layout_(layout),
      buffer_(&buffer) {
  // Fail fast if the requested layout cannot apply to this buffer's shape.
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
79 
ToString() const80 string BufferLayoutConstraint::ToString() const {
81   return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
82                          LayoutUtil::HumanString(layout_));
83 }
84 
// Constructs a constraint fixing the layout of operand `operand_no` of
// `instruction` to `shape_layout`. The layout must already be set and its
// shape must be compatible with the operand's shape (both CHECK-verified).
OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64_t operand_no, bool mandatory, bool dfs, int64_t priority)
    : LayoutConstraint(mandatory, dfs, priority),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}
99 
ToString() const100 string OperandLayoutConstraint::ToString() const {
101   return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
102                          instruction_->name(), operand_no_,
103                          shape_layout_.ToString());
104 }
105 
ToString() const106 string ResultLayoutConstraint::ToString() const {
107   return absl::StrFormat("ResultLayoutConstraint: %s",
108                          shape_layout_.ToString());
109 }
110 
// Builds the constraint-tracking state for `computation`. Initially every
// array-shaped logical buffer defined in this computation is recorded as
// unconstrained.
LayoutConstraints::LayoutConstraints(
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation)
    : points_to_analysis_(points_to_analysis), computation_(computation) {
  // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
  for (HloInstruction* inst : computation_->instructions()) {
    points_to_analysis_.GetPointsToSet(inst).ForEachElement(
        [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
          for (const LogicalBuffer* buffer : buffers) {
            // The points to analysis is computed per module, restrict
            // constraints to array buffers in this computation.
            if (buffer->IsArray() &&
                buffer->instruction()->parent() == computation) {
              unconstrained_buffer_ids_.insert(buffer->id());
            }
          }
        });
  }
}
130 
GetBufferSet(const HloInstruction * instruction) const131 PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
132     const HloInstruction* instruction) const {
133   auto it = buffer_sets_cache_.find(instruction);
134   if (it != buffer_sets_cache_.end()) {
135     return it->second.get();
136   }
137   auto& buffer_set =
138       buffer_sets_cache_
139           .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
140           .first->second;
141   const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
142   points_to_set.ForEachElement(
143       [&buffer_set](const ShapeIndex& /*index*/,
144                     const PointsToSet::BufferList& buffers) {
145         buffer_set->insert(buffers.begin(), buffers.end());
146       });
147   return buffer_set.get();
148 }
149 
AnyOperandBufferForwarded(const HloInstruction * instruction,int64_t operand_no) const150 bool LayoutConstraints::AnyOperandBufferForwarded(
151     const HloInstruction* instruction, int64_t operand_no) const {
152   // The operand is potentially forwarded if the intersection of points-to sets
153   // of the operand and the instruction is non-empty.
154   PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
155   PointsToSet::BufferSet* operand_buffers =
156       GetBufferSet(instruction->operand(operand_no));
157   return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
158     return operand_buffers->count(b) > 0;
159   });
160 }
161 
AllOperandBuffersForwarded(const HloInstruction * instruction,int64_t operand_no) const162 bool LayoutConstraints::AllOperandBuffersForwarded(
163     const HloInstruction* instruction, int64_t operand_no) const {
164   // The operand is potentially forwarded if the intersection of points-to sets
165   // of the operand and the instruction is non-empty.
166   PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
167   PointsToSet::BufferSet* operand_buffers =
168       GetBufferSet(instruction->operand(operand_no));
169   return absl::c_all_of(*output_buffers, [&](const LogicalBuffer* b) {
170     return operand_buffers->count(b) > 0;
171   });
172 }
173 
SetBufferLayout(const Layout & layout,const LogicalBuffer & buffer,bool mandatory,bool dfs)174 Status LayoutConstraints::SetBufferLayout(const Layout& layout,
175                                           const LogicalBuffer& buffer,
176                                           bool mandatory, bool dfs) {
177   VLOG(3) << "SetBufferLayout : " << buffer << " : "
178           << LayoutUtil::HumanString(layout);
179 
180   TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
181   if (!buffer.IsArray()) {
182     return FailedPrecondition(
183         "Layout of buffer %s cannot be constrained because buffer is not "
184         "array-shaped, has shape: %s",
185         buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
186   }
187   TF_RETURN_IF_ERROR(
188       LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
189 
190   auto iter = buffer_constraints_.find(&buffer);
191   if (iter != buffer_constraints_.end()) {
192     const BufferLayoutConstraint& curr_constraint = iter->second;
193     if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
194       // New constraint matches existing constraint. Nothing to do.
195       return Status::OK();
196     }
197     if (curr_constraint.mandatory()) {
198       if (!mandatory) {
199         VLOG(3) << "Buffer" << buffer
200                 << " already has a mandatory layout constrain, skipping";
201         return Status::OK();
202       }
203       return FailedPrecondition(
204           "Buffer %s already has the layout constraint %s, cannot add "
205           "incompatible constraint %s",
206           buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
207           LayoutUtil::HumanString(layout));
208     }
209     iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
210   } else {
211     TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
212         << buffer.ToString();
213     iter = buffer_constraints_
214                .insert(std::make_pair(
215                    &buffer,
216                    BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
217                .first;
218   }
219   added_constraints_.push_back(&iter->second);
220   return Status::OK();
221 }
222 
// Adds a constraint fixing the layout of operand `operand_no` of
// `instruction` to `shape_with_layout`.
//
// Returns an error if the operand already carries an incompatible mandatory
// constraint, or if any of the operand's buffers are forwarded into the
// instruction's output (constraining a forwarded buffer would ripple beyond
// this single use and is not handled here).
Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
                                           const HloInstruction* instruction,
                                           int64_t operand_no, bool mandatory,
                                           bool dfs) {
  auto priority = LayoutConstraint::kDefaultPriority;
  // The second and third operands (operand_no > 0) of a dynamic-update-slice
  // operation typically have much smaller sizes than the first (operand_no==0)
  // operand. It is necessary to downgrade the importance of the smaller
  // operands, so that the overall layout choice of the operation is dictated
  // by operand 0 when possible.
  if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice &&
      operand_no > 0 && !mandatory) {
    dfs = false;
    priority--;
  }
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout)
          << " : priority = " << priority << "\n";
  const OperandLayoutConstraint* curr_shape_layout =
      GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout, /*minor_to_major_only=*/true)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_shape_layout->mandatory()) {
      return FailedPrecondition(
          "Operand %d of instruction %s already has a layout constraint "
          "%s, cannot add incompatible constraint %s",
          operand_no, instruction->name(),
          curr_shape_layout->shape_layout().ToString(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout));
    }
  }

  // If any buffers in the operand occur in the output of the instruction, then
  // return an error. This case is not handled because such a constraint changes
  // layouts beyond this immediate use and is complicated to handle.
  if (AnyOperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constraint layout of operand %d of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name());
  }

  // Insert a fresh constraint or overwrite a non-mandatory existing one.
  auto key = std::make_pair(instruction, operand_no);
  auto iter = operand_constraints_.find(key);
  if (iter == operand_constraints_.end()) {
    auto pair = std::make_pair(
        key,
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs, priority));
    iter = operand_constraints_.insert(pair).first;
  } else {
    iter->second =
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs, priority);
  }
  added_constraints_.push_back(&iter->second);
  // Move the new constraint from the end of the worklist forward if its
  // priority is higher than those ahead of it in the worklist. Assuming the
  // worklist is already sorted, this works like a bubble sort, where the lower
  // priority constraints are always pushed to the tail of the worklist.
  for (int64_t i = added_constraints_.size() - 2; i >= 0; --i) {
    if (added_constraints_[i]->priority() <
        added_constraints_[i + 1]->priority()) {
      std::swap(added_constraints_[i + 1], added_constraints_[i]);
    } else {
      break;
    }
  }
  return Status::OK();
}
298 
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64_t operand_no,bool mandatory,bool dfs)299 Status LayoutConstraints::SetArrayOperandLayout(
300     const Layout& layout, const HloInstruction* instruction, int64_t operand_no,
301     bool mandatory, bool dfs) {
302   const HloInstruction* operand = instruction->operand(operand_no);
303   TF_RET_CHECK(operand->shape().IsArray());
304   Shape shape(operand->shape());
305   *shape.mutable_layout() = layout;
306   TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
307   return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
308 }
309 
SetResultLayout(const Shape & shape_with_layout,bool dfs)310 Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
311                                           bool dfs) {
312   VLOG(3) << "SetResultLayout : "
313           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
314 
315   const ShapeLayout* curr_shape_layout = ResultLayout();
316   if (curr_shape_layout != nullptr) {
317     if (!curr_shape_layout->MatchesLayoutInShape(
318             shape_with_layout, /*minor_to_major_only=*/true)) {
319       return FailedPrecondition(
320           "Result of computation %s already has the layout constraint %s, "
321           "cannot add incompatible constraint %s",
322           computation_->name(), curr_shape_layout->ToString(),
323           ShapeUtil::HumanStringWithLayout(shape_with_layout));
324     }
325     // New constraint matches existing constraint. Nothing to do.
326     return Status::OK();
327   }
328   result_constraint_.reset(
329       new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
330   added_constraints_.push_back(result_constraint_.get());
331 
332   return Status::OK();
333 }
334 
// Constrains the layout of every array-shaped buffer in `instruction`'s
// output to match the corresponding subshape of `shape_with_layout`.
//
// Unless `allow_alias` is true, the instruction must itself define every
// buffer in its output (none may be forwarded from an operand).
Status LayoutConstraints::SetInstructionLayout(
    const Shape& shape_with_layout, const HloInstruction* instruction,
    bool mandatory, bool dfs, bool allow_alias) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanStringWithLayout(shape_with_layout));
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory, allow_alias](
          const Shape& subshape, const ShapeIndex& index) -> Status {
        auto buffers =
            points_to_analysis_.GetPointsToSet(instruction).element(index);
        // Each output element is expected to map to exactly one buffer.
        CHECK_EQ(1, buffers.size());
        if (!allow_alias) {
          CHECK_EQ(buffers[0]->instruction(), instruction);
        }

        // Only array subshapes that actually carry a layout are constrained;
        // tuples and layout-less subshapes are skipped.
        if (subshape.IsArray() && subshape.has_layout()) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
        } else {
          return Status::OK();
        }
      });
}
368 
BufferLayout(const LogicalBuffer & buffer) const369 const Layout* LayoutConstraints::BufferLayout(
370     const LogicalBuffer& buffer) const {
371   if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
372     return &constraint->layout();
373   }
374   return nullptr;
375 }
376 
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const377 const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
378     const LogicalBuffer& buffer) const {
379   auto it = buffer_constraints_.find(&buffer);
380   return it == buffer_constraints_.end() ? nullptr : &it->second;
381 }
382 
OperandLayout(const HloInstruction * instruction,int64_t operand_no) const383 const ShapeLayout* LayoutConstraints::OperandLayout(
384     const HloInstruction* instruction, int64_t operand_no) const {
385   if (const auto* constraint =
386           GetOperandLayoutConstraint(instruction, operand_no)) {
387     return &constraint->shape_layout();
388   }
389   return nullptr;
390 }
391 
GetOperandLayoutConstraint(const HloInstruction * instruction,int64_t operand_no) const392 const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
393     const HloInstruction* instruction, int64_t operand_no) const {
394   auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
395   return it == operand_constraints_.end() ? nullptr : &it->second;
396 }
397 
ResultLayout() const398 const ShapeLayout* LayoutConstraints::ResultLayout() const {
399   return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
400 }
401 
ToString() const402 string LayoutConstraints::ToString() const {
403   string output;
404   absl::StrAppend(&output, "LayoutConstraints for computation ",
405                   computation_->name(), ":\n");
406   for (auto* instruction : computation_->MakeInstructionPostOrder()) {
407     absl::StrAppend(&output, "  ", instruction->ToShortString(), "\n");
408     for (int64_t i = 0; i < instruction->operand_count(); ++i) {
409       if (OperandLayout(instruction, i) != nullptr) {
410         absl::StrAppend(&output, "    operand (", i,
411                         "): ", OperandLayout(instruction, i)->ToString(), "\n");
412       }
413     }
414     for (const LogicalBuffer* buffer :
415          points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
416       if (BufferLayout(*buffer) != nullptr) {
417         absl::StrAppend(&output, "    ", buffer->ToString(), " : ",
418                         LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
419       }
420     }
421   }
422 
423   if (ResultLayout() != nullptr) {
424     absl::StrAppend(&output, "  => ", ResultLayout()->ToString(), "\n");
425   }
426   return output;
427 }
428 
429 namespace {
430 
IsHostSendRecv(const HloInstruction * instruction)431 bool IsHostSendRecv(const HloInstruction* instruction) {
432   const HloSendRecvInstruction* send_recv_instr =
433       DynCast<HloSendRecvInstruction>(instruction);
434   return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
435 }
436 
437 }  // namespace
438 
// Records the layouts of host Send/Recv instructions in `computation` into
// host_channel_constraints_, keyed by channel id.
//
// Returns an error if a channel is constrained twice (each host channel must
// carry a single layout) or if a host transfer's data shape is not an array
// with a layout.
Status LayoutAssignment::BuildHostChannelConstraints(
    HloComputation* computation) {
  for (auto* instruction : computation->instructions()) {
    const HloSendRecvInstruction* send_recv_instr =
        DynCast<HloSendRecvInstruction>(instruction);
    // Only host-transfer Send/Recv ops participate.
    if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
      continue;
    }

    // For host transfers the Send and Recv instruction carry the layout.
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      // Tuple element 0 of the Send/Recv shape is the transferred data.
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
      TF_RET_CHECK(data_shape.IsArray());
      TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
      // ConstrainChannel returns the previously-registered layout, or nullptr
      // if this channel had not been constrained yet.
      const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
          *send_recv_instr->channel_id(), data_shape.layout());
      TF_RET_CHECK(prev_layout == nullptr)
          << "Cannot constrain host transfer layout as it was set to "
          << LayoutUtil::HumanString(*prev_layout) << ": "
          << send_recv_instr->ToString();
    }
  }
  return Status::OK();
}
465 
466 namespace {
467 
IsLayoutConstrainedCustomCall(HloInstruction * instruction)468 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
469   const HloCustomCallInstruction* custom_call =
470       DynCast<HloCustomCallInstruction>(instruction);
471   return custom_call != nullptr && custom_call->layout_constrained();
472 }
473 
IsLayoutConstrainedCollective(const HloInstruction * instruction)474 bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
475   const HloCollectiveInstruction* collective =
476       DynCast<HloCollectiveInstruction>(instruction);
477   return collective != nullptr && collective->constrain_layout();
478 }
479 
// Recursively propagates `shape` (a parameter's layout-carrying shape) to the
// users of `instruction` as non-mandatory constraints: get-tuple-element
// users receive the matching tuple element's layout as their instruction
// layout (and propagate further), all other users receive it as an operand
// layout constraint.
Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
                                       const Shape& shape,
                                       LayoutConstraints* constraints) {
  for (auto* user : instruction->users()) {
    // Excluding tuple operations as they do not participate in layout
    // propagations (they do not create or alias buffers).
    if (user->opcode() == HloOpcode::kTuple) {
      continue;
    }
    VLOG(3) << "Setting  user layout : " << user->ToString();
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // Peel off the tuple element this GTE selects and recurse through it.
      auto tuple_index = user->tuple_index();
      CHECK(shape.IsTuple());
      auto elem_shape = shape.tuple_shapes(tuple_index);
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          elem_shape, user, /*mandatory=*/false, /*dfs=*/false,
          /*allow_alias=*/true));
      TF_RETURN_IF_ERROR(
          PropagateParameterLayoutToUsers(user, elem_shape, constraints));
    } else {
      // Constrain the operand slot through which `user` consumes this value.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          shape, user, user->operand_index(instruction), /*mandatory=*/false,
          /*dfs=*/false));
    }
  }
  return Status::OK();
}
507 
508 }  // namespace
509 
AddMandatoryConstraints(const ComputationLayout * computation_layout,ChannelLayoutConstraints * channel_constraints,HloComputation * computation,LayoutConstraints * constraints)510 Status LayoutAssignment::AddMandatoryConstraints(
511     const ComputationLayout* computation_layout,
512     ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
513     LayoutConstraints* constraints) {
514   VLOG(2) << "Adding mandatory layout constraints to computation "
515           << computation->name();
516 
517   auto get_channel_constraints = [&](const HloInstruction* instruction) {
518     return IsHostSendRecv(instruction) ? &host_channel_constraints_
519                                        : channel_constraints;
520   };
521 
522   // Constrain layouts of instructions which define values with pre-existing
523   // layouts.
524   for (auto* instruction : computation->instructions()) {
525     if (instruction->opcode() == HloOpcode::kInfeed) {
526       // Infeed layouts must match the layout of the original inserted
527       // instruction.
528       // TODO(b/31425034): Change infeeds to be more like parameters, with
529       // shapes in the ComputationLayout.
530       TF_RETURN_IF_ERROR(
531           constraints->SetInstructionLayout(instruction->shape(), instruction));
532     } else if (instruction->opcode() == HloOpcode::kOutfeed) {
533       // Constrain the input to the Outfeed instruction to be the expected
534       // layout of the Outfeed.
535       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
536           instruction->outfeed_shape(), instruction, 0));
537     } else if (instruction->opcode() == HloOpcode::kParameter) {
538       if (computation_layout != nullptr) {
539         const ShapeLayout& parameter_layout =
540             computation_layout->parameter_layout(
541                 instruction->parameter_number());
542         // Parameter layouts must match the respective layout in
543         // ComputationLayout, if there is one.
544         TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
545             parameter_layout.shape(), instruction));
546         if (reverse_computation_order_) {
547           TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers(
548               instruction, parameter_layout.shape(), constraints));
549         }
550       }
551     } else if (IsLayoutConstrainedCustomCall(instruction)) {
552       const HloCustomCallInstruction* custom_call =
553           DynCast<HloCustomCallInstruction>(instruction);
554 
555       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
556           custom_call->shape(), custom_call, /*mandatory=*/true, /*dfs=*/true,
557           /*allow_alias=*/true));
558 
559       for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
560         if (constraints->AnyOperandBufferForwarded(custom_call, i)) {
561           TF_RET_CHECK(constraints->AllOperandBuffersForwarded(custom_call, i))
562               << "Partial alias of an operand is not supported";
563         } else {
564           TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
565               custom_call->operand_shapes_with_layout()[i], custom_call, i));
566         }
567       }
568     } else if (instruction->opcode() == HloOpcode::kSend ||
569                instruction->opcode() == HloOpcode::kRecv) {
570       CHECK(get_channel_constraints(instruction))
571           << "Multi-module layout assignment requires ChannelLayoutConstraints";
572       int64_t channel_id = *instruction->channel_id();
573       if (!get_channel_constraints(instruction)
574                ->IsChannelConstrained(channel_id)) {
575         continue;
576       }
577       if (instruction->opcode() == HloOpcode::kSend) {
578         // TODO(b/68493863): Change to use SetOperandLayout().
579         const Shape send_buffer_shape = instruction->operand(0)->shape();
580         TF_RET_CHECK(send_buffer_shape.IsArray());
581         Shape new_buffer_shape =
582             get_channel_constraints(instruction)
583                 ->LayoutShapeForChannel(send_buffer_shape,
584                                         *instruction->channel_id());
585         TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
586             new_buffer_shape, instruction->operand(0)));
587       } else {
588         const Shape recv_buffer_shape =
589             ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
590         TF_RET_CHECK(recv_buffer_shape.IsArray());
591         TF_ASSIGN_OR_RETURN(
592             const LogicalBuffer* buffer,
593             constraints->points_to_analysis().GetBufferDefinedAt(instruction,
594                                                                  {0}));
595         Shape new_shape =
596             get_channel_constraints(instruction)
597                 ->LayoutShapeForChannel(recv_buffer_shape,
598                                         *instruction->channel_id());
599         TF_RETURN_IF_ERROR(
600             constraints->SetBufferLayout(new_shape.layout(), *buffer));
601       }
602     } else if (IsLayoutConstrainedCollective(instruction)) {
603       TF_RETURN_IF_ERROR(
604           constraints->SetInstructionLayout(instruction->shape(), instruction));
605     } else if (instruction->IsCrossModuleAllReduce()) {
606       CHECK(get_channel_constraints(instruction))
607           << "Multi-module layout assignment requires ChannelLayoutConstraints";
608       int64_t channel_id = instruction->channel_id().value();
609       if (!get_channel_constraints(instruction)
610                ->IsChannelConstrained(channel_id)) {
611         continue;
612       }
613       // TODO(b/68493863): Change to use SetOperandLayout().
614       const Shape& buffer_shape = instruction->operand(0)->shape();
615       TF_RET_CHECK(buffer_shape.IsArray());
616       Shape new_buffer_shape =
617           get_channel_constraints(instruction)
618               ->LayoutShapeForChannel(buffer_shape, channel_id);
619       TF_RETURN_IF_ERROR(
620           constraints->SetInstructionLayout(new_buffer_shape, instruction));
621     }
622   }
623 
624   // Constrain layouts of instructions which call computations which have
625   // already been assigned layouts. Instructions which call computations in a
626   // parallel element-wise context (eg, map or reduce) do not need layout
627   // constraints because they operate on scalars.
628   for (auto* instruction : computation->instructions()) {
629     if (instruction->opcode() == HloOpcode::kCall &&
630         computation_layouts_.find(instruction->to_apply()) !=
631             computation_layouts_.end()) {
632       // kCall instruction operands and output must match the ComputationLayout
633       // of the called computation.
634       const ComputationLayout& called_computation_layout =
635           FindOrDie(computation_layouts_, instruction->to_apply());
636       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
637           called_computation_layout.result_layout().shape(), instruction));
638       TF_RET_CHECK(instruction->operand_count() ==
639                    called_computation_layout.parameter_count());
640       for (int64_t i = 0; i < instruction->operand_count(); ++i) {
641         TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
642             called_computation_layout.parameter_layout(i).shape(), instruction,
643             i));
644       }
645     } else if (instruction->opcode() == HloOpcode::kWhile &&
646                computation_layouts_.find(instruction->while_body()) !=
647                    computation_layouts_.end()) {
648       // Layout of input and output of kWhile instruction must be equal and must
649       // match both input and output of body computation. Also, the input of
650       // condition computation must match kWhile layout.
651       HloComputation* body = instruction->while_body();
652       HloComputation* condition = instruction->while_condition();
653       const HloInstruction* init = instruction->operand(0);
654       ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
655       ComputationLayout& condition_layout =
656           FindOrDie(computation_layouts_, condition);
657 
658       // Check a few invariants irrespective of layout.
659       CHECK_EQ(1, instruction->operand_count());
660       CHECK_EQ(1, body->num_parameters());
661       CHECK_EQ(1, condition->num_parameters());
662       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
663                                    body_layout.parameter_shape(0)));
664       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
665                                    condition_layout.parameter_shape(0)));
666       DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
667 
668       if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
669         VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
670                 << " while=" << instruction->name()
671                 << " shape=" << body_layout.result_layout().ToString();
672         *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
673       }
674       if (condition_layout.parameter_layout(0) !=
675           body_layout.parameter_layout(0)) {
676         VLOG(2) << "Reset %while condition parameter layout: cond="
677                 << condition->name() << " while=" << instruction->name()
678                 << " shape=" << body_layout.parameter_layout(0).ToString();
679         *condition_layout.mutable_parameter_layout(0) =
680             body_layout.parameter_layout(0);
681       }
682 
683       // Constrain the output and the operand of the while instruction to match
684       // the computations.
685       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
686           body_layout.result_shape(), instruction, 0));
687       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
688           body_layout.result_shape(), instruction));
689     } else if (instruction->opcode() == HloOpcode::kConditional &&
690                computation_layouts_.find(instruction->branch_computation(0)) !=
691                    computation_layouts_.end()) {
692       // Find the conditional branch with the most instructions and force all
693       // other computations to match that layout. A potentially better decision
694       // could count the number FLOPs or how constrained the layouts are.
695       int64_t largest_branch = -1;
696       int64_t largest_instruction_count = 0;
697       for (int j = 0; j < instruction->branch_count(); ++j) {
698         const int64_t instruction_count =
699             instruction->branch_computation(j)->instruction_count();
700         if (instruction_count > largest_instruction_count &&
701             !ShapeUtil::IsEmptyTuple(instruction->operand(j + 1)->shape())) {
702           largest_branch = j;
703           largest_instruction_count = instruction_count;
704         }
705       }
706       if (largest_branch == -1) {
707         largest_branch = 0;
708       }
709       ComputationLayout& best_branch_computation_layout =
710           FindOrDie(computation_layouts_,
711                     instruction->branch_computation(largest_branch));
712       for (int k = 0; k < instruction->branch_count(); ++k) {
713         // Visit the best branch first.
714         int j = (k + largest_branch) % instruction->branch_count();
715         TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
716         ComputationLayout& branch_computation_layout =
717             FindOrDie(computation_layouts_, instruction->branch_computation(k));
718         if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
719                 best_branch_computation_layout.result_layout().shape(),
720                 /*minor_to_major_only=*/true)) {
721           computation_layouts_.erase(instruction->branch_computation(k));
722           InsertOrDie(&conditional_mismatch_,
723                       instruction->branch_computation(k),
724                       best_branch_computation_layout);
725         } else {
726           TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
727               branch_computation_layout.parameter_shape(0), instruction, k + 1,
728               /*mandatory=*/true));
729         }
730       }
731       TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
732           best_branch_computation_layout.parameter_shape(0), instruction,
733           largest_branch + 1,
734           /*mandatory=*/true));
735       TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
736           best_branch_computation_layout.result_shape(), instruction));
737     }
738   }
739   // Finally set the result layout to match ComputationLayout, if there is one.
740   if (conditional_mismatch_.count(computation) > 0) {
741     TF_RETURN_IF_ERROR(constraints->SetResultLayout(
742         FindOrDie(conditional_mismatch_, computation).result_layout().shape()));
743   } else if (computation_layout != nullptr) {
744     const ShapeLayout& result_layout = computation_layout->result_layout();
745     if (result_layout.LayoutIsSet()) {
746       TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
747     }
748   }
749   return Status::OK();
750 }
751 
752 namespace {
753 
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)754 bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
755   return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
756 }
757 
758 // The operands of a call must match the layouts of parameters in the
759 // ComputationLayout, and the call instruction itself must match the result
760 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)761 Status CheckCallLayout(HloInstruction* call,
762                        const ComputationLayout& computation_layout) {
763   HloComputation* computation = call->to_apply();
764   TF_RET_CHECK(computation->num_parameters() == call->operand_count());
765   for (int64_t i = 0; i < computation->num_parameters(); ++i) {
766     TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
767         call->operand(i)->shape(), /*minor_to_major_only=*/true));
768   }
769   TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
770       call->shape(), /*minor_to_major_only=*/true));
771   return Status::OK();
772 }
773 
774 // Operands of layout-constrained custom calls must match the expected
775 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)776 Status CheckCustomCallLayout(HloInstruction* instruction) {
777   if (IsLayoutConstrainedCustomCall(instruction)) {
778     const HloCustomCallInstruction* custom_call =
779         DynCast<HloCustomCallInstruction>(instruction);
780     for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
781       TF_RET_CHECK(
782           LayoutsInShapesEqual(custom_call->operand(i)->shape(),
783                                custom_call->operand_shapes_with_layout()[i]));
784     }
785   }
786   return Status::OK();
787 }
788 
789 // For a while instruction, all the following layouts must be the same:
790 //   (1) init operand
791 //   (2) condition computation parameter
792 //   (3) body computation parameter
793 //   (4) body computation result
794 //   (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)795 Status CheckWhileLayout(HloInstruction* while_inst,
796                         const ComputationLayout& condition_computation_layout,
797                         const ComputationLayout& body_computation_layout) {
798   auto init_shape = while_inst->operand(0)->shape();
799   TF_RET_CHECK(
800       condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
801           init_shape, /*minor_to_major_only=*/true));
802   TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
803       init_shape, /*minor_to_major_only=*/true));
804   TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
805       init_shape, /*minor_to_major_only=*/true));
806   TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
807   return Status::OK();
808 }
809 
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)810 Status CheckConditionalLayout(
811     HloInstruction* instruction,
812     absl::Span<const ComputationLayout> branch_computation_layouts) {
813   for (int j = 0; j < instruction->branch_count(); ++j) {
814     const HloInstruction* branch_operand = instruction->operand(j + 1);
815     TF_RET_CHECK(
816         branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
817             branch_computation_layouts[j].result_layout().shape(),
818             /*minor_to_major_only=*/true));
819     TF_RET_CHECK(
820         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
821             instruction->shape(), /*minor_to_major_only=*/true));
822     TF_RET_CHECK(
823         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
824             instruction->branch_computation(j)->root_instruction()->shape(),
825             /*minor_to_major_only=*/true));
826     TF_RET_CHECK(
827         branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
828             branch_operand->shape(), /*minor_to_major_only=*/true));
829   }
830   return Status::OK();
831 }
832 
833 // Fusion parameters must match the layout of the fusion instructions operands,
834 // and the root of the fusion expression must match the layout of the fusion
835 // instruction.
CheckFusionLayout(HloInstruction * fusion)836 Status CheckFusionLayout(HloInstruction* fusion) {
837   TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
838 
839   TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
840                                     fusion->fused_expression_root()->shape()));
841   for (int64_t i = 0; i < fusion->operand_count(); ++i) {
842     TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
843                                       fusion->operand(i)->shape()));
844   }
845   return Status::OK();
846 }
847 
848 // The layout of a parameter must match the respective layout in the
849 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)850 Status CheckParameterLayout(HloInstruction* parameter,
851                             const ComputationLayout& computation_layout) {
852   const ShapeLayout& parameter_layout =
853       computation_layout.parameter_layout(parameter->parameter_number());
854   return ShapeUtil::ForEachSubshapeWithStatus(
855       parameter_layout.shape(),
856       [&](const Shape& subshape, const ShapeIndex& shape_index) {
857         if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
858             !subshape.has_layout()) {
859           return Status::OK();
860         }
861         if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
862                 subshape,
863                 ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
864           return InternalError(
865               "parameter instruction %s does not match layout of computation "
866               "shape: %s",
867               parameter->ToString(), parameter_layout.ToString());
868         }
869         return Status::OK();
870       });
871 }
872 
873 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)874 Status CheckConstantLayout(HloInstruction* constant) {
875   if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
876     return InternalError(
877         "constant instruction %s does not match the layout of its literal %s",
878         constant->ToString(),
879         ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
880   }
881   return Status::OK();
882 }
883 
884 }  // namespace
885 
// Creates a copy of `instruction` whose layout matches `shape_with_layout`
// and returns the root of the copy. For tuples, a tree of copies is built:
// elements whose layout already matches are reused via get-tuple-element
// without copying; mismatched elements are copied recursively.
// `shape_with_layout` must be compatible with the instruction's shape and
// must have a layout at every array subshape.
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      // Extract each element so it can be copied (or reused) independently.
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
                                                    instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // Reset the new tuple's layout to exactly the requested target layout
    // rather than whatever its operands implied.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    // Record this kCopy as one introduced by layout assignment itself.
    // NOTE: only the array branch registers; the tuple branch's Tuple/GTE
    // instructions are not registered here.
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    // Give the copy the requested layout instead of the source layout.
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
942 
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64_t operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
                                                operand->shape())) {
    VLOG(2) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }
  VLOG(2) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // If the operand is only used by a conditional, do the copy inside the branch
  // to avoid overhead for other branches.
  if (!reverse_computation_order_ &&
      instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
      instruction->operand(operand_no)->user_count() == 1) {
    // Operand `operand_no` feeds branch computation `operand_no - 1`
    // (operand 0 is the branch selector).
    auto branch_comp = instruction->branch_computation(operand_no - 1);
    auto param = branch_comp->parameter_instruction(0);
    // Rewrite the branch parameter to accept the operand's current layout,
    // then copy to the required layout inside the branch.
    *param->mutable_shape() = operand->shape();
    // Snapshot users before adding the copy, since the copy itself becomes a
    // new user of `param`.
    auto param_users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
                        CreateCopyWithNewLayout(operand_layout.shape(), param));
    for (auto user : param_users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
    }
    VLOG(2) << "New copy of " << operand->ToString() << " is "
            << param_copy->ToString();
    if (param == branch_comp->root_instruction()) {
      branch_comp->set_root_instruction(param_copy,
                                        /*accept_different_shape=*/true);
    }
    // Keep the recorded branch-computation parameter layout consistent with
    // the rewritten parameter shape.
    *FindOrDie(computation_layouts_, branch_comp).mutable_parameter_layout(0) =
        ShapeLayout(operand->shape());
    return Status::OK();
  }

  // General case: insert a copy (deep copy for tuples) next to the operand
  // and redirect this single use to it.
  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
995 
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)996 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
997                                               HloInstruction* copy,
998                                               const ShapeIndex& index) {
999   if (instruction.has_sharding()) {
1000     // If the index is empty, we want to copy the whole sharding, in case the
1001     // sharding is a tuple sharding.
1002     HloSharding sharding =
1003         !index.empty() && instruction.sharding().IsTuple()
1004             ? instruction.sharding().GetSubSharding(instruction.shape(), index)
1005             : instruction.sharding();
1006     // We propagate the sharding to the copied instruction only if it is a
1007     // special sharding, like tiled ones.
1008     // Otherwise it is preferable to leave the new instruction without device,
1009     // and let the automatic device placer to choose the best location.
1010     auto device = sharding.UniqueDevice();
1011     if (!device || HloSharding::IsReservedDevice(*device)) {
1012       copy->set_sharding(sharding);
1013     }
1014   }
1015   copy->set_metadata(instruction.metadata());
1016 }
1017 
// Verifies that layout assignment produced a consistent module: every
// instruction has a valid layout, each output subshape matches the layout of
// any logical buffer that may flow into it, special-constraint instructions
// (call/while/conditional/fusion/parameter/constant/custom-call) satisfy
// their layout contracts, and the entry result layout matches the module's
// result shape.
Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  // Fusion computations are skipped: their interior layouts are checked via
  // the fusion instruction itself (CheckFusionLayout).
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              // Any buffer that may reach this leaf must agree on layout
              // (minor-to-major only; dynamic dimensions ignored).
              for (const LogicalBuffer* buffer : buffers) {
                if (!Shape::Equal()
                         .IgnoreDynamicDimension()
                         .MinorToMajorOnlyInLayout()(instruction_subshape,
                                                     buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional: {
          // Gather all branch computation layouts so they can be checked
          // against each other and against the conditional's shape.
          std::vector<ComputationLayout> branch_computation_layouts;
          for (auto branch_computation : instruction->branch_computations()) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation));
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            module->result_shape(), result_layout.shape()));
  }
  return Status::OK();
}
1111 
// Constructs the pass.
//
// `entry_computation_layout`: the (possibly partially specified) layouts of
//   the entry computation's parameters and result; a pristine copy is saved
//   so the pass can restore it later.
// `channel_constraints`: optional cross-module layout constraints for
//   channel instructions; may be nullptr. A copy is kept so the input
//   constraints can be restored when the pass undoes previous side effects.
// `reverse_computation_order`: alters the order in which computations are
//   processed (see uses of reverse_computation_order_).
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    ChannelLayoutConstraints* channel_constraints,
    bool reverse_computation_order)
    : entry_computation_layout_(entry_computation_layout),
      saved_entry_computation_layout_(*entry_computation_layout),
      reverse_computation_order_(reverse_computation_order),
      channel_layout_constraints_(channel_constraints) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
1128 
ChooseOperandLayoutFromOutputLayout(const Layout & output_layout,const HloInstruction * instruction,int64_t operand_no)1129 std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
1130     const Layout& output_layout, const HloInstruction* instruction,
1131     int64_t operand_no) {
1132   const HloInstruction* operand = instruction->operand(operand_no);
1133   CHECK(instruction->shape().IsArray());
1134   CHECK(operand->shape().IsArray());
1135   if (!ShapeUtil::IsScalar(operand->shape()) &&
1136       operand->shape().rank() == instruction->shape().rank() &&
1137       !InstructionCanChangeLayoutInstance(instruction)) {
1138     // Propagate the result layout to the operand layout if the instruction
1139     // requires the same layout out for the result and the operand.
1140     //
1141     // For elementwise operations, using the same layout for the operands and
1142     // the result also has the following benefits:
1143     // 1) the elementwise operation can reuse its operand's buffer, and
1144     // 2) the input and output elements can reuse the same linear index.
1145     return absl::make_unique<Layout>(output_layout);
1146   }
1147 
1148   if (instruction->opcode() == HloOpcode::kReshape) {
1149     // Prefer the operand layout that makes the reshape an bitcast. If any
1150     // dimension bound is 1 in the operand shape, there may be several such
1151     // layouts. So if 'output_layout' is the default layout, try if the
1152     // reshape is a bitcast when using the same layout. This may avoid copy
1153     // operations. For similar reasons, if the operand and output have the same
1154     // rank, try to match the operand's layout to the output.
1155     if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
1156         ShapeUtil::TrueRank(instruction->shape()) == 1) {
1157       // Don't assign a layout in case of R1 -> effective R1 reshape.
1158       return nullptr;
1159     }
1160 
1161     const Shape& output_shape = instruction->shape();
1162     Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
1163         output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
1164         LayoutUtil::MinorToMajor(output_layout));
1165     Shape operand_shape = operand->shape();
1166     *operand_shape.mutable_layout() =
1167         LayoutUtil::GetDefaultLayoutForShape(operand_shape);
1168     auto aligned_operand_shape =
1169         ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
1170     if (aligned_operand_shape) {
1171       auto operand_layout = aligned_operand_shape.value().layout();
1172       TF_CHECK_OK(
1173           LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
1174       return absl::make_unique<Layout>(operand_layout);
1175     }
1176   }
1177 
1178   if (instruction->opcode() == HloOpcode::kTranspose) {
1179     // Pick the operand layout that makes the transpose a bitcast.
1180     int64_t rank = instruction->shape().rank();
1181     std::vector<int64> new_minor_to_major(rank);
1182     for (int64_t i = 0; i < rank; ++i) {
1183       int64_t output_dim = LayoutUtil::Minor(output_layout, i);
1184       int64_t operand_dim = instruction->dimensions(output_dim);
1185       new_minor_to_major[i] = operand_dim;
1186     }
1187     Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
1188     TF_CHECK_OK(
1189         LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
1190     return absl::make_unique<Layout>(operand_layout);
1191   }
1192 
1193   return nullptr;
1194 }
1195 
ChooseOutputLayoutFromOperandLayout(const Layout & operand_layout,const HloInstruction * user,int64_t operand_no)1196 std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
1197     const Layout& operand_layout, const HloInstruction* user,
1198     int64_t operand_no) {
1199   const HloInstruction* operand = user->operand(operand_no);
1200 
1201   CHECK(user->shape().IsArray() && operand->shape().IsArray());
1202 
1203   if (!ShapeUtil::IsScalar(operand->shape()) &&
1204       operand->shape().rank() == user->shape().rank() &&
1205       !InstructionCanChangeLayoutInstance(user)) {
1206     // Assign users the same layout as the operand.
1207     return absl::make_unique<Layout>(operand_layout);
1208   }
1209 
1210   if (user->opcode() == HloOpcode::kReshape) {
1211     // Prefer the user layout that makes the reshape an bitcast. If any
1212     // dimension bound is 1 in the user shape, there may be several such
1213     // layouts. So if 'operand_layout' is the default layout, try if the
1214     // reshape is a bitcast when using the same layout. This may avoid copy
1215     // operations. For similar reasons, if the operand and output have the same
1216     // rank, try to match the outputs's layout to the operand.
1217     if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
1218         ShapeUtil::TrueRank(user->shape()) == 1) {
1219       // Don't assign a layout in case of R1 -> effective R1 reshape.
1220       return nullptr;
1221     }
1222     Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
1223         operand->shape().element_type(),
1224         AsInt64Slice(operand->shape().dimensions()),
1225         LayoutUtil::MinorToMajor(operand_layout));
1226     Shape output_shape = user->shape();
1227     *output_shape.mutable_layout() =
1228         LayoutUtil::GetDefaultLayoutForShape(output_shape);
1229     auto aligned_user_shape =
1230         ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
1231     if (aligned_user_shape) {
1232       auto user_layout = aligned_user_shape.value().layout();
1233       TF_CHECK_OK(
1234           LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
1235       return absl::make_unique<Layout>(user_layout);
1236     }
1237   }
1238 
1239   if (user->opcode() == HloOpcode::kTranspose) {
1240     // Pick the user layout that makes the transpose a bitcast.
1241     int64_t rank = user->shape().rank();
1242     std::vector<int64> new_minor_to_major(rank);
1243     auto inverse_dimensions = InversePermutation(user->dimensions());
1244     for (int64_t i = 0; i < rank; ++i) {
1245       int64_t operand_dim = LayoutUtil::Minor(operand_layout, i);
1246       int64_t user_dim = inverse_dimensions[operand_dim];
1247       new_minor_to_major[i] = user_dim;
1248     }
1249     Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
1250     TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
1251     return absl::make_unique<Layout>(user_layout);
1252   }
1253 
1254   return nullptr;
1255 }
1256 
PropagateConstraints(LayoutConstraints * constraints)1257 Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
1258   // Gathers all initial constraints in a worklist and propagates them in
1259   // depth-first order. DFS order seems to be better than BFS because a
1260   // constraint is propagated as far as possible before propagating unrelated
1261   // constraints which makes it less likely that conflicting constraints will be
1262   // propagated to instructions. However, we should experiment with other orders
1263   // too.
1264   std::deque<const LayoutConstraint*> worklist;
1265 
1266   // Lambda for moving newly added constraints to the worklist.
1267   auto add_new_constraints_to_worklist = [constraints, &worklist]() {
1268     // Add constraints to the front of the deque for DFS ordering.
1269     for (auto* constraint : constraints->ConsumeAddedConstraints()) {
1270       if (constraint->dfs()) {
1271         worklist.push_front(constraint);
1272       } else {
1273         VLOG(3) << "push back constraint for propagation : "
1274                 << constraint->ToString();
1275         worklist.push_back(constraint);
1276       }
1277     }
1278   };
1279   add_new_constraints_to_worklist();
1280 
1281   while (!worklist.empty()) {
1282     const LayoutConstraint* layout_constraint = worklist.front();
1283     worklist.pop_front();
1284     VLOG(2) << "Propagating " << layout_constraint->ToString()
1285             << " to its neighbors with priority = "
1286             << layout_constraint->priority() << "\n";
1287     if (auto* buffer_constraint =
1288             dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
1289       TF_RETURN_IF_ERROR(
1290           PropagateBufferConstraint(*buffer_constraint, constraints));
1291     } else if (auto* operand_constraint =
1292                    dynamic_cast<const OperandLayoutConstraint*>(
1293                        layout_constraint)) {
1294       TF_RETURN_IF_ERROR(
1295           PropagateOperandConstraint(*operand_constraint, constraints));
1296     } else if (auto* result_constraint =
1297                    dynamic_cast<const ResultLayoutConstraint*>(
1298                        layout_constraint)) {
1299       TF_RETURN_IF_ERROR(
1300           PropagateResultConstraint(*result_constraint, constraints));
1301     } else {
1302       LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
1303     }
1304 
1305     add_new_constraints_to_worklist();
1306   }
1307   return Status::OK();
1308 }
1309 
1310 namespace {
1311 
1312 // Returns a vector containing all array-shaped uses (instruction and operand
1313 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const LogicalBuffer & buffer,const TuplePointsToAnalysis & points_to_analysis)1314 std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
1315     const LogicalBuffer& buffer,
1316     const TuplePointsToAnalysis& points_to_analysis) {
1317   CHECK(buffer.IsArray());
1318   std::vector<std::pair<const HloInstruction*, int64>> uses;
1319   for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
1320     if (!buffer_alias.instruction()->shape().IsArray()) {
1321       continue;
1322     }
1323     // This alias must be the top-level (index == {}) of the instruction's
1324     // result because the instruction produces an array.
1325     CHECK(buffer_alias.index().empty());
1326 
1327     // Add all uses of the instruction's output.
1328     for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1329       for (int64_t operand_no :
1330            user->OperandIndices(buffer_alias.instruction())) {
1331         uses.emplace_back(user, operand_no);
1332       }
1333     }
1334   }
1335   return uses;
1336 }
1337 
1338 }  // namespace
1339 
PropagateUseConstraintToDefs(const ShapeLayout & shape_layout,const HloInstruction * instruction,LayoutConstraints * constraints)1340 Status LayoutAssignment::PropagateUseConstraintToDefs(
1341     const ShapeLayout& shape_layout, const HloInstruction* instruction,
1342     LayoutConstraints* constraints) {
1343   // Try to set all logical buffers which may be sources of the given operand to
1344   // match the given layout.
1345   const PointsToSet& points_to_set =
1346       constraints->points_to_analysis().GetPointsToSet(instruction);
1347   return points_to_set.ForEachElementWithStatus(
1348       [&shape_layout, constraints](
1349           const ShapeIndex& index,
1350           const PointsToSet::BufferList& buffers) -> Status {
1351         if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
1352           for (const LogicalBuffer* buffer : buffers) {
1353             if (constraints->BufferLayout(*buffer) == nullptr &&
1354                 buffer->shape().IsArray()) {
1355               TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
1356                   ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
1357                   *buffer, /*mandatory=*/true));
1358             }
1359           }
1360         }
1361         return Status::OK();
1362       });
1363 }
1364 
1365 namespace {
1366 // A transpose or a reshape that only changes trivial dimensions have meaningful
1367 // layouts that are valuable to propagate in a depthfirst manner to avoid
1368 // unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo,bool forward_propagation=true)1369 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
1370                                           bool forward_propagation = true) {
1371   switch (hlo.opcode()) {
1372     case HloOpcode::kFusion:
1373       return hlo.IsCustomFusion();
1374     case HloOpcode::kGather:
1375       return true;
1376     case HloOpcode::kReshape:
1377       return hlo.operand(0)->shape().rank() == 1 ||
1378              (forward_propagation &&
1379               std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
1380     case HloOpcode::kScatter:
1381     case HloOpcode::kTranspose:
1382       return true;
1383     default:
1384       return false;
1385   }
1386 }
1387 
1388 }  // namespace
1389 
// Propagates a layout constraint placed on one operand of 'user':
//   1. backwards, to the buffers that define the operand's value;
//   2. sideways, to same-rank sibling operands when 'user' cannot change
//      layouts; and
//   3. forwards, to the buffers 'user' defines (directly for
//      non-layout-changing users, via ChooseOutputLayoutFromOperandLayout
//      otherwise).
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    return Status::OK();
  }

  if (user->opcode() == HloOpcode::kAllReduce) {
    // A multi-operand all-reduce produces a tuple; the output element that
    // corresponds to this operand is at {operand_no}. A single-operand
    // all-reduce produces the array directly (index {}).
    const auto shape_index =
        user->operand_count() == 1
            ? ShapeIndex()
            : ShapeIndex({operand_constraint.operand_no()});
    TF_ASSIGN_OR_RETURN(const LogicalBuffer* buffer,
                        constraints->points_to_analysis().GetBufferDefinedAt(
                            user, shape_index));
    const BufferLayoutConstraint* constraint =
        constraints->GetBufferLayoutConstraint(*buffer);
    // Suggest (non-mandatory) the operand's layout for the matching output
    // element unless it is already constrained.
    if (constraint == nullptr) {
      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          operand_constraint.shape_layout().layout(), *buffer,
          /*mandatory=*/false));
    }
  }
  if (InstructionCanChangeLayoutInstance(user) && !user->shape().IsArray()) {
    return Status::OK();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->AnyOperandBufferForwarded(user,
                                             operand_constraint.operand_no())) {
    return Status::OK();
  }

  int64_t operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    return Status::OK();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!InstructionCanChangeLayoutInstance(user)) {
    // Only propagate the layout of the largest concatenate operand.
    if (user->opcode() == HloOpcode::kConcatenate) {
      for (int64_t operand_no = 0; operand_no < user->operand_count();
           ++operand_no) {
        const HloInstruction* sibling = user->operand(operand_no);
        if (sibling == operand) {
          continue;
        }
        // Bail out if any sibling is larger along the concatenation
        // dimension; its constraint should win instead.
        if (sibling->shape().dimensions(user->concatenate_dimension()) >
            operand->shape().dimensions(user->concatenate_dimension())) {
          return Status::OK();
        }
      }
    }
    // Make sure all siblings have the same layout as the operand.
    for (int64_t operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      const int64_t sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        //   operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        //   we continue to propagate our layout to the result, and then
        //   backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false));
    }
    // Non-layout-changing user: also suggest the operand's layout for every
    // same-rank array buffer the user defines in its output.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return Status::OK();
          }
          if (subshape.rank() <= 1) {
            return Status::OK();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return Status::OK();
          }
          if (!constraints->points_to_analysis()
                   .InstructionDefinesBufferAtIndex(user, shape_index)) {
            return Status::OK();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(
                  user, shape_index));
          // Make sure the output has the same layout as the operand.
          const BufferLayoutConstraint* constraint =
              constraints->GetBufferLayoutConstraint(*buffer);
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          if (constraint == nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                operand_constraint.shape_layout().layout(), *buffer,
                /*mandatory=*/false));
          }
          return Status::OK();
        }));
    return Status::OK();
  }
  // Layout-changing user: ask ChooseOutputLayoutFromOperandLayout for a low
  // cost output layout, and apply it to each array buffer the user defines
  // unless a mandatory constraint already exists for that buffer.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return Status::OK();
        }
        if (subshape.rank() <= 1) {
          return Status::OK();
        }
        if (!constraints->points_to_analysis().InstructionDefinesBufferAtIndex(
                user, shape_index)) {
          return Status::OK();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(user,
                                                                 shape_index));
        if (constraints->BufferLayout(*buffer) == nullptr ||
            !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            // Reduce outputs get a mandatory constraint; the layout is
            // propagated depth-first when the user qualifies for it.
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}
1570 
// Propagates a buffer's layout constraint backwards into the operands of the
// instruction that defines the buffer. Non-layout-changing instructions force
// (mandatory) the same layout onto same-rank array operands; layout-changing
// instructions get a non-mandatory suggestion computed by
// ChooseOperandLayoutFromOutputLayout.
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  // Rank <= 1 shapes have (at most) one valid layout; nothing to propagate.
  if (IsAtMostRank1(instruction->shape())) {
    return Status::OK();
  }

  if (instruction->opcode() == HloOpcode::kAllReduce) {
    // For all-reduce, only the operand matching this buffer's tuple index
    // (operand 0 in the single-operand, non-tuple case) must share the layout.
    TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
        buffer_constraint.layout(), instruction,
        instruction->operand_count() == 1 ? 0 : buffer.index()[0],
        /*mandatory=*/true));
    return Status::OK();
  }
  for (int64_t operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!InstructionCanChangeLayoutInstance(instruction)) {
      // Copy the layout to the operand. Only array operands whose rank
      // matches the constrained layout's rank can take it directly.
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true));
      }
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no mandatory constraint already.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          // forward_propagation=false: we are walking from output to operand.
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/
              InstructionShouldPropagateDepthFirst(
                  *instruction, /*forward_propagation=*/false)));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return Status::OK();
}
1634 
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1635 Status LayoutAssignment::PropagateBufferConstraint(
1636     const BufferLayoutConstraint& buffer_constraint,
1637     LayoutConstraints* constraints) {
1638   // Only propagate array layouts.
1639   const LogicalBuffer& buffer = buffer_constraint.buffer();
1640   if (!buffer.IsArray()) {
1641     return Status::OK();
1642   }
1643   TF_RETURN_IF_ERROR(
1644       PropagateBufferConstraintToOperands(buffer_constraint, constraints));
1645   return PropagateBufferConstraintToUses(buffer_constraint, constraints);
1646 }
1647 
// Propagates a buffer's layout constraint forwards to every array-shaped use
// of the buffer (as a non-mandatory operand constraint), and additionally
// through the backedge of an enclosing while loop into the loop body's
// parameter.
Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
    const HloInstruction* user = user_operand_no.first;
    int64_t operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !constraints->AnyOperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
    }
  }

  // Propagate to backedges of kWhile. Only applies when this computation is
  // called from exactly one site, and that caller is a while instruction.
  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
  if (node.caller_callsites().size() != 1) {
    return Status::OK();
  }
  const HloInstruction* parent = node.caller_callsites()[0].instruction();
  if (parent->opcode() != HloOpcode::kWhile) {
    return Status::OK();
  }

  for (HloInstruction* user : buffer.instruction()->users()) {
    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
      continue;
    }
    // Only the use that is itself the loop body's root tuple feeds the
    // backedge.
    if (user->parent()->root_instruction() == user) {
      VLOG(3) << "Propagating layout through backedge"
              << buffer_constraint.layout().ToString();
      int64_t index = user->operand_index(buffer.instruction());
      // NOTE: this inner 'buffer' intentionally shadows the outer one; it is
      // the corresponding element of the loop body's parameter tuple.
      TF_ASSIGN_OR_RETURN(
          auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
                           user->parent()->parameter_instruction(0), {index}));

      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          buffer_constraint.layout(), *buffer, /*mandatory=*/false));
    }
  }

  return Status::OK();
}
1698 
PropagateResultConstraint(const ResultLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1699 Status LayoutAssignment::PropagateResultConstraint(
1700     const ResultLayoutConstraint& layout_constraint,
1701     LayoutConstraints* constraints) {
1702   // Propagate the use constraint of the root instruction up to the logical
1703   // buffers which make up the result.
1704   return PropagateUseConstraintToDefs(
1705       layout_constraint.shape_layout(),
1706       constraints->computation()->root_instruction(), constraints);
1707 }
1708 
1709 namespace {
1710 
1711 // Infers the layout of the array at the given index in the given instruction's
1712 // output using points-to analysis. Precondition: The given instruction must
1713 // not produce this array value (that is, the array is forwarded from the
1714 // instruction's operands).
InferArrayLayout(const TuplePointsToAnalysis & points_to_analysis,HloInstruction * instruction,const ShapeIndex & index)1715 StatusOr<Layout> InferArrayLayout(
1716     const TuplePointsToAnalysis& points_to_analysis,
1717     HloInstruction* instruction, const ShapeIndex& index) {
1718   // This function should only be called for array shapes which don't yet have
1719   // layouts.
1720   const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
1721   TF_RET_CHECK(subshape.IsArray());
1722   TF_RET_CHECK(!subshape.has_layout());
1723 
1724   // The instruction should not define the buffer at this index.
1725   TF_RET_CHECK(
1726       !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
1727       << instruction->ToString();
1728 
1729   const auto& source_buffers =
1730       points_to_analysis.GetPointsToSet(instruction).element(index);
1731   TF_RET_CHECK(!source_buffers.empty());
1732 
1733   // Verify the layout is the same for every LogicalBuffer which this location
1734   // ('instruction' and 'index') points to.
1735   const Layout* first_buffer_layout = nullptr;
1736   for (const LogicalBuffer* source_buffer : source_buffers) {
1737     if (!source_buffer->shape().has_layout()) {
1738       // This should not happen because we've assigned layouts to all
1739       // instructions preceding this one.
1740       return InternalError("LogicalBuffer %s does not have a layout",
1741                            source_buffer->ToString());
1742     }
1743 
1744     if (first_buffer_layout == nullptr) {
1745       first_buffer_layout = &source_buffer->shape().layout();
1746     } else if (!Layout::Equal().MinorToMajorOnly()(
1747                    source_buffer->shape().layout(), *first_buffer_layout)) {
1748       // The points-to set is ambiguous for this index and the different source
1749       // buffers have different layouts. This case is possible in valid XLA
1750       // computations because we do not propagate BufferLayoutConstraints to all
1751       // LogicalBuffers which may alias the constrained LogicalBuffer at some
1752       // point in the computation.
1753       return FailedPrecondition(
1754           "Array at index {%s} in instruction %s aliases buffers %s "
1755           "and %s which have different layouts",
1756           absl::StrJoin(index, ","), instruction->name(),
1757           source_buffers[0]->ToString(), source_buffer->ToString());
1758     }
1759   }
1760 
1761   return *first_buffer_layout;
1762 }
1763 
// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
// instruction itself. NOTE: the branches below are order-sensitive — the
// fused-root check precedes the GTE/constant cases, so a root that is e.g. a
// GTE takes the fusion instruction's layout, not its operand's.
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Fused parameters mirror the layout of the fusion operand they stand
      // in for.
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes.
      // But do not clear layouts for other instructions in custom fusion
      // nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return Status::OK();
}
1809 
1810 }  // namespace
1811 
AssignLayouts(const LayoutConstraints & constraints,HloComputation * computation)1812 Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
1813                                        HloComputation* computation) {
1814   VLOG(2) << "Assigning layouts to computation: " << computation->name();
1815   XLA_VLOG_LINES(2, computation->ToString());
1816   XLA_VLOG_LINES(2, constraints.ToString());
1817 
1818   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
1819     LayoutUtil::ClearLayout(instruction->mutable_shape());
1820 
1821     // Set the layouts of the array shapes this instruction defines as indicated
1822     // by the respective BufferLayoutConstraints. Any array shapes in the output
1823     // of the instruction which are not defined by the instruction (eg, array
1824     // elements in a Tuple instruction) will be assigned below via inference.
1825     for (const LogicalBuffer* buffer :
1826          constraints.points_to_analysis().GetBuffersDefinedByInstruction(
1827              instruction)) {
1828       if (!buffer->shape().IsArray()) {
1829         continue;
1830       }
1831 
1832       TF_RET_CHECK(buffer->instruction() == instruction);
1833       const Layout* buffer_layout = constraints.BufferLayout(*buffer);
1834       TF_RET_CHECK(buffer_layout != nullptr);
1835 
1836       if (instruction->opcode() == HloOpcode::kConstant) {
1837         // For constants, we also need to change the layout of the internal
1838         // literal.
1839         instruction->RelayoutConstant(*buffer_layout, buffer->index());
1840       } else {
1841         Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
1842             instruction->mutable_shape(), buffer->index());
1843         *buffer_subshape->mutable_layout() = *buffer_layout;
1844       }
1845     }
1846 
1847     // Any remaining layouts in the output of the instruction must be
1848     // inferrable using points-to analysis.
1849     TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
1850         instruction->mutable_shape(),
1851         [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
1852           if (subshape->has_layout() || !subshape->IsArray()) {
1853             return Status::OK();
1854           }
1855           // Set Layout of subshape to match layout of LogicalBuffer which
1856           // produces it.
1857           TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
1858                               InferArrayLayout(constraints.points_to_analysis(),
1859                                                instruction, index));
1860           return Status::OK();
1861         }));
1862     // Create a copy of an operand if the operand instruction's layout does not
1863     // match the use constraint (OperandLayoutConstraint).
1864     for (int64_t operand_no = 0; operand_no < instruction->operand_count();
1865          ++operand_no) {
1866       const ShapeLayout* operand_layout =
1867           constraints.OperandLayout(instruction, operand_no);
1868       if (operand_layout != nullptr) {
1869         TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
1870                                                       instruction, operand_no));
1871       } else {
1872         VLOG(2) << "operand " << operand_no << " has no constraint";
1873       }
1874     }
1875 
1876     // Process instructions that contain nested computations and may require
1877     // additional layouts to be assigned on the instructions nested inside.
1878     switch (instruction->opcode()) {
1879       case HloOpcode::kFusion:
1880         TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
1881         break;
1882       case HloOpcode::kCall:
1883         if (reverse_computation_order_) {
1884           ProgramShape program_shape;
1885           for (int64_t i = 0; i < instruction->operand_count(); ++i) {
1886             *program_shape.add_parameters() = instruction->operand(i)->shape();
1887           }
1888           *program_shape.mutable_result() = instruction->shape();
1889           computation_layouts_.emplace(
1890               instruction->to_apply(),
1891               ComputationLayout(program_shape,
1892                                 /*ignore layout=*/false));
1893         }
1894         break;
1895       case HloOpcode::kConditional:
1896         // If the branches don't yet have layouts, propagate existing layout
1897         // inside the branches.
1898         if (reverse_computation_order_) {
1899           for (int i = 0; i < instruction->branch_count(); ++i) {
1900             ProgramShape program_shape;
1901             auto* branch_operand = instruction->operand(i + 1);
1902             *program_shape.add_parameters() = branch_operand->shape();
1903             *program_shape.mutable_result() = instruction->shape();
1904             computation_layouts_.emplace(
1905                 instruction->branch_computation(i),
1906                 ComputationLayout(program_shape,
1907                                   /*ignore layout=*/false));
1908           }
1909         }
1910         break;
1911       case HloOpcode::kWhile:
1912         // If the loop body doesn't have layouts, propagate existing one inside.
1913         if (reverse_computation_order_) {
1914           VLOG(2) << "Populating while loop constraints inside loop body.";
1915           VLOG(2) << instruction->ToString();
1916           auto init_shape = instruction->operand(0)->shape();
1917           ProgramShape body_shape;
1918           *body_shape.add_parameters() = init_shape;
1919           *body_shape.mutable_result() = init_shape;
1920           computation_layouts_.emplace(
1921               instruction->while_body(),
1922               ComputationLayout(body_shape,
1923                                 /*ignore layout=*/false));
1924           VLOG(2) << "Populating while loop constraints inside loop condition.";
1925           VLOG(2) << instruction->ToString();
1926           ProgramShape condition_shape;
1927           *condition_shape.add_parameters() = init_shape;
1928           *condition_shape.mutable_result() = ShapeUtil::MakeScalarShape(PRED);
1929         }
1930         break;
1931       default:
1932         break;
1933     }
1934     // Execute extra verification step once the layout has been finalized.
1935     TF_RETURN_IF_ERROR(Verify(instruction));
1936 
1937     // Shape must be valid.
1938     TF_RETURN_IF_ERROR(
1939         ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
1940 
1941     // Verify all layouts in the shape have been set.
1942     TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
1943   }
1944   return Status::OK();
1945 }
1946 
CalculateComputationLayout(HloComputation * computation)1947 Status LayoutAssignment::CalculateComputationLayout(
1948     HloComputation* computation) {
1949   if (!reverse_computation_order_ ||
1950       computation_layouts_.find(computation) == computation_layouts_.end()) {
1951     ComputationLayout computation_layout(computation->ComputeProgramShape(),
1952                                          /*ignore_layouts=*/false);
1953     InsertOrDie(&computation_layouts_, computation, computation_layout);
1954     VLOG(2) << "  Calculated ComputationLayout = "
1955             << computation_layout.ToString();
1956   }
1957   return Status::OK();
1958 }
1959 
ClearComputationLayouts(HloComputation * computation)1960 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
1961   // Clear existing layouts of the instructions.  All layouts must be assigned
1962   // by the LayoutAssignment pass, except for those on parameters, the
1963   // computation result, and a couple special cases. The former two are
1964   // specified in computation_layout.  Clearing the layouts here avoids hiding
1965   // potential bugs in the layout assignment pass that may accidentally use the
1966   // existing layout.
1967   for (HloInstruction* instruction : computation->instructions()) {
1968     if (instruction->opcode() == HloOpcode::kBitcast) {
1969       // bitcasts are inherently layout sensitive and so a bitcast instruction
1970       // present in the IR before layout assignment is a bug.
1971       return InternalError(
1972           "Unexpected bitcast operation seen during layout assignment: %s.",
1973           instruction->ToString());
1974     }
1975     // Some instructions carry mandatory layouts in their shape.
1976     if (instruction->opcode() != HloOpcode::kInfeed &&
1977         !IsLayoutConstrainedCustomCall(instruction) &&
1978         !IsLayoutConstrainedCollective(instruction)) {
1979       LayoutUtil::ClearLayout(instruction->mutable_shape());
1980     }
1981   }
1982   return Status::OK();
1983 }
1984 
RunOnComputation(ComputationLayout * computation_layout,HloComputation * computation,ChannelLayoutConstraints * channel_constraints)1985 Status LayoutAssignment::RunOnComputation(
1986     ComputationLayout* computation_layout, HloComputation* computation,
1987     ChannelLayoutConstraints* channel_constraints) {
1988   VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
1989           << ")";
1990 
1991   // Must be run before clearing layouts.
1992   TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));
1993 
1994   TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
1995   if (computation_layout != nullptr) {
1996     auto it = computation_layouts_.find(computation);
1997     if (it == computation_layouts_.end()) {
1998       VLOG(2) << "  New ComputationLayout = " << computation_layout->ToString();
1999       computation_layouts_.emplace(computation, *computation_layout);
2000     } else {
2001       TF_RET_CHECK(computation_layout == &it->second ||
2002                    computation_layout == entry_computation_layout_);
2003       VLOG(2) << "  Existing ComputationLayout = "
2004               << computation_layout->ToString();
2005     }
2006   } else {
2007     VLOG(2) << "  No ComputationLayout specified (will be calculated)";
2008   }
2009 
2010   // Construct LayoutConstraints with all layout constraints of the computation.
2011   LayoutConstraints constraints(*points_to_analysis_, computation);
2012 
2013   // Add constraints required for correctness on all backends (eg, entry
2014   // parameter layout constraints).
2015   TF_RETURN_IF_ERROR(AddMandatoryConstraints(
2016       computation_layout, channel_constraints, computation, &constraints));
2017 
2018   // Add any backend-specific constraints.
2019   TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
2020 
2021   // Propagates layouts from mandatory and backend constraints.
2022   TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
2023 
2024   // Prior to applying default layouts, we take note of all HLO instructions
2025   // which lack a layout constraint.
2026   for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
2027     unconstrained_layout_instructions_.insert(
2028         points_to_analysis_->GetBuffer(buffer_id).instruction());
2029   }
2030 
2031   // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
2032   // layout and propagate the change.
2033   while (!constraints.unconstrained_buffer_ids().empty()) {
2034     int unconstrained_count = constraints.unconstrained_buffer_ids().size();
2035 
2036     // Arbitrarily pick the first unconstrained buffer and give it the default
2037     // layout (or the literal layout, in case of constants). By construction
2038     // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
2039     const LogicalBuffer& buffer = points_to_analysis_->GetBuffer(
2040         *constraints.unconstrained_buffer_ids().begin());
2041     const HloInstruction* instruction = buffer.instruction();
2042     Layout new_layout =
2043         instruction->opcode() == HloOpcode::kConstant
2044             ? ShapeUtil::GetSubshape(instruction->literal().shape(),
2045                                      buffer.index())
2046                   .layout()
2047             : GetUnconstrainedLayout(buffer);
2048     TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
2049                                                    /*mandatory=*/false));
2050 
2051     TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
2052 
2053     // To verify progress has been made, check that the number of unconstrained
2054     // buffers has been reduced.
2055     CHECK_LT(constraints.unconstrained_buffer_ids().size(),
2056              unconstrained_count);
2057   }
2058   // All logical buffers should have constraints at this point. All that
2059   // remains is assign the constraints to the buffers and infer layouts for
2060   // aliased buffers.
2061   TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));
2062 
2063   // If the computation layout wasn't specified, now it is the time to compute
2064   // it according to the parameters and root instruction layouts.
2065   // This allows the first pass through this API to record the best flowing
2066   // layout to parameters and root instruction.
2067   if (computation_layout == nullptr) {
2068     TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
2069   }
2070 
2071   // Record the layouts assigned for any communication ops in
2072   // channel_constraints so that they are constrained for future modules.
2073   if (channel_constraints != nullptr) {
2074     TF_RETURN_IF_ERROR(
2075         ConstrainChannelLayouts(computation, channel_constraints));
2076   }
2077 
2078   // Copy the root instruction's result if its layout does not match the result
2079   // layout constraint.
2080   if (constraints.ResultLayout() != nullptr &&
2081       !constraints.ResultLayout()->MatchesLayoutInShape(
2082           computation->root_instruction()->shape(),
2083           /*minor_to_major_only=*/true)) {
2084     if (conditional_mismatch_.count(computation) > 0) {
2085       *FindOrDie(computation_layouts_, computation).mutable_result_layout() =
2086           FindOrDie(conditional_mismatch_, computation).result_layout();
2087     }
2088     TF_ASSIGN_OR_RETURN(
2089         HloInstruction * new_root,
2090         CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
2091                                 computation->root_instruction()));
2092     computation->set_root_instruction(new_root);
2093   }
2094   return Status::OK();
2095 }
2096 
// Registers the layouts assigned to channel instructions (kRecvDone, kSend,
// and cross-module all-reduce) in the per-channel constraint tables so that
// future users of the same channels pick matching layouts.
Status LayoutAssignment::ConstrainChannelLayouts(
    HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  // Host send/recv channels are tracked in a separate table from device
  // channels.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };
  // Process the kRecvDone instructions first. These must either impose their
  // layout, or find a matching one already registered (ConstrainChannel()
  // returns nullptr in that case); a conflicting pre-existing layout is an
  // error because a recv's layout cannot be adjusted after the fact.
  for (HloInstruction* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kRecvDone) {
      const Layout* layout =
          get_channel_constraints(instruction)
              ->ConstrainChannel(
                  *instruction->channel_id(),
                  ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
      TF_RET_CHECK(layout == nullptr)
          << instruction->ToString()
          << " cannot constrain layout as it was set to "
          << LayoutUtil::HumanString(*layout);
    }
  }
  // After that we go through the kSend. These are likely going to have a kCopy
  // as operand (otherwise we add it), so in case the constrained layout does
  // not match, we can change the kCopy layout (and the kSend one as well).
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() == HloOpcode::kSend) {
      HloInstruction* operand = instruction->mutable_operand(0);
      get_channel_constraints(instruction)
          ->ConstrainChannel(*instruction->channel_id(),
                             operand->shape().layout());
    } else if (instruction->IsCrossModuleAllReduce()) {
      get_channel_constraints(instruction)
          ->ConstrainChannel(instruction->channel_id().value(),
                             instruction->shape().layout());
    }
  }
  return Status::OK();
}
2137 
PropagateMemorySpace(HloModule * module)2138 Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
2139   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
2140   for (const auto& buffer : alias_analysis->buffers()) {
2141     // First go through values to collect the memory spaces.
2142     int64_t buffer_memory_space = Layout::kDefaultMemorySpace;
2143     for (auto value : buffer.values()) {
2144       const Shape& defining_shape = value->defining_position().shape();
2145       int64_t memory_space = defining_shape.layout().memory_space();
2146       if (memory_space != Layout::kDefaultMemorySpace) {
2147         if (buffer_memory_space != Layout::kDefaultMemorySpace &&
2148             memory_space != buffer_memory_space) {
2149           return InternalError(
2150               "Buffer %d (%s) has conflicting memory spaces: %d and %d.",
2151               buffer.id(), value->ToShortString(), buffer_memory_space,
2152               memory_space);
2153         }
2154         buffer_memory_space = memory_space;
2155       }
2156     }
2157 
2158     // If we encounter a memory space other than the default, then propagate all
2159     // the positions with the buffer's memory space.
2160     if (buffer_memory_space != Layout::kDefaultMemorySpace) {
2161       for (auto value : buffer.values()) {
2162         for (auto& position : value->positions()) {
2163           Shape* shape = ShapeUtil::GetMutableSubshape(
2164               position.instruction->mutable_shape(), position.index);
2165           shape->mutable_layout()->set_memory_space(buffer_memory_space);
2166         }
2167       }
2168     }
2169   }
2170   return Status::OK();
2171 }
2172 
// Back-fills `computation_layout` (the caller-facing layout, typically the
// entry computation layout) from the layouts the pass actually assigned to
// `computation`, and verifies that any leaf layouts the caller had already
// fixed agree with the assignment.
Status LayoutAssignment::PropagateComputationLayouts(
    HloComputation* computation, ComputationLayout* computation_layout) {
  // The layout the computation ended up with after assignment.
  ComputationLayout computed_computation_layout(
      computation->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  for (int64_t i = 0; i < computed_computation_layout.parameter_count(); ++i) {
    ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
    bool needs_assign = false;
    // Walk every leaf of the parameter shape: an unset leaf layout means the
    // whole parameter layout should be adopted from the computed one; a set
    // leaf must match what was assigned.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        param_layout->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
            return Status::OK();
          }
          if (!subshape.has_layout()) {
            needs_assign = true;
            return Status::OK();
          }
          const auto& computed_subshape = ShapeUtil::GetSubshape(
              computed_computation_layout.parameter_shape(i), shape_index);
          if (subshape.layout() != computed_subshape.layout()) {
            return InternalError(
                "Assigned parameter shape %s does not match layout of "
                "computation shape: %s",
                computed_computation_layout.ToString(),
                computation_layout->ToString());
          }
          return Status::OK();
        }));
    if (needs_assign) {
      VLOG(4) << "Assigning layout to parameter " << i << " of computation "
              << computation->name() << ": "
              << computed_computation_layout.parameter_layout(i).ToString();
      *param_layout = computed_computation_layout.parameter_layout(i);
    }
  }
  ShapeLayout* result_layout = computation_layout->mutable_result_layout();
  if (!result_layout->LayoutIsSet()) {
    VLOG(4) << "Assigning result layout of computation " << computation->name()
            << ": " << computed_computation_layout.result_layout().ToString();
    *result_layout = computed_computation_layout.result_layout();
  } else {
    // A pre-set result layout must match the assigned one, comparing only
    // minor-to-major order and ignoring dynamic dimensions.
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            computed_computation_layout.result_layout().shape(),
            result_layout->shape()));
  }
  return Status::OK();
}
2222 
// Pass entry point: resets per-run state, then assigns a complete layout to
// every instruction shape in `module`. Always returns true since all layouts
// are cleared and reassigned.
StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
  VLOG(2) << "Running layout assignment on module " << module->name();
  TF_RETURN_IF_ERROR(Init());
  call_graph_ = CallGraph::Build(module);
  auto computations = module->computations();

  // Add copy to the operand of Send instructions, since we cannot call
  // SetOperandLayout on Send instructions as it aliases its input to the
  // output.
  //
  // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
  // operand buffers that aliases with the output.
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kSend) {
        TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
      }
    }
  }

  // Clone computations that are a branch of more than one kConditional
  // caller, so that each conditional can constrain its branch layouts
  // independently.
  for (HloComputation* computation : computations) {
    CallGraphNode& node = call_graph_->GetNode(computation);
    if (node.caller_callsites().size() == 1) {
      continue;
    }
    if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
          return caller.instruction()->opcode() == HloOpcode::kConditional;
        })) {
      continue;
    }
    // Clone for all callsites but the last, which keeps the original.
    for (int64_t i = 0; i < node.caller_callsites().size() - 1; ++i) {
      HloInstruction* caller = node.caller_callsites()[i].instruction();
      if (caller->opcode() == HloOpcode::kConditional) {
        for (int64_t k = 0; k < caller->branch_count(); ++k) {
          if (computation == caller->branch_computation(k)) {
            caller->set_branch_computation(
                k, module->AddEmbeddedComputation(computation->Clone()));
            break;
          }
        }
      }
    }
  }

  // Verify computation layout is sane.
  const HloComputation* entry = module->entry_computation();
  TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
               entry->num_parameters());
  for (int64_t i = 0; i < entry->num_parameters(); ++i) {
    TF_RET_CHECK(
        ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
                              entry->parameter_instruction(i)->shape()));
  }
  TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
                                     entry->root_instruction()->shape()));
  // We do two passes. The first one we pass a nullptr ComputationLayout to
  // the RunOnComputation() calls (for non entry computations), and we register
  // the ComputationLayout which are naturally flowing in DFS fashion to the
  // parameters and root instruction.
  // Walking in DFS mode though, means that we can end up with incorrect layouts
  // when seen from an outer instruction, which has across-computation
  // constraints to impose.
  // For example, the kWhile instruction needs to enforce the same layouts for
  // the parameters and root of the body, as well as the condition parameters.
  // Similarly, the kConditional instruction needs to enforce the same layouts
  // for the root of the true and false computations.
  // So in the first pass, while allowing the layouts to flow to parameters and
  // root, we also fix up the eventually inconsistent ComputationLayout, which
  // will be then made mandatory by the second pass.
  for (int64_t i = 0; i < 2; ++i) {
    VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
    TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
    TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                        TuplePointsToAnalysis::Run(module));
    points_to_analysis_ = std::move(points_to_analysis);
    auto computations_to_work = module->MakeComputationPostOrder();
    // If the reverse_computation_order_ flag is set, reverse the ordering of
    // traversing computations, to generate an alternative layout assignment.
    if (reverse_computation_order_ && !computations_to_work.empty()) {
      absl::c_reverse(computations_to_work);

      VLOG(2) << "reversing traversal order for computation:";
    }
    for (auto* computation : computations_to_work) {
      // Fusion computations inherit layouts from their fusion instruction.
      if (computation->IsFusionComputation()) {
        continue;
      }
      if (computation == module->entry_computation()) {
        TF_RETURN_IF_ERROR(RunOnComputation(entry_computation_layout_,
                                            module->entry_computation(),
                                            channel_layout_constraints_));
      } else {
        // Second pass (or reverse order with a seeded layout): impose the
        // recorded ComputationLayout; otherwise let layouts flow freely.
        ComputationLayout* computation_layout =
            (i == 1 && conditional_mismatch_.count(computation) == 0) ||
                    (reverse_computation_order_ &&
                     computation_layouts_.find(computation) !=
                         computation_layouts_.end())
                ? &FindOrDie(computation_layouts_, computation)
                : nullptr;
        TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation,
                                            channel_layout_constraints_));
      }
    }
    VLOG(2) << "After running #" << i << ":" << module->ToString();
  }
  TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
                                                 entry_computation_layout_));

  TF_RETURN_IF_ERROR(PropagateMemorySpace(module));

  TF_RETURN_IF_ERROR(CheckLayouts(module));

  // All layouts are reset then reassigned by this pass.
  return true;
}
2340 
2341 /* static */
// Returns true iff `instruction` is allowed to have a layout different from
// its operands', i.e. it can act as a layout-conversion point. Opcodes that
// return false must use the same layout for operands and result, so layouts
// propagate through them unchanged. The switch is exhaustive (no default) so
// adding a new opcode forces a decision here at compile time.
bool LayoutAssignment::InstructionCanChangeLayout(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kWhile:
    case HloOpcode::kSetDimensionSize:
    // AllReduce is variadic so it needs to be careful to assign the same layout
    // to the corresponding input argument and Tuple index.
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
      return false;
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kAfterAll:
    case HloOpcode::kTrace:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
      return true;
  }
}
2464 
InstructionCanChangeLayoutInstance(const HloInstruction * instruction)2465 bool LayoutAssignment::InstructionCanChangeLayoutInstance(
2466     const HloInstruction* instruction) {
2467   return InstructionCanChangeLayout(instruction);
2468 }
2469 
2470 /* static */
IsAtMostRank1(const Shape & shape)2471 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2472   if (shape.IsArray()) {
2473     return shape.rank() <= 1;
2474   }
2475   return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2476     return IsAtMostRank1(subshape);
2477   });
2478 }
2479 
Init()2480 Status LayoutAssignment::Init() {
2481   computation_layouts_.clear();
2482   conditional_mismatch_.clear();
2483   *entry_computation_layout_ = saved_entry_computation_layout_;
2484   return Status::OK();
2485 }
2486 
ClearPreviousPassSideEffects(HloModule * module)2487 Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
2488   VLOG(5) << "Clearing previous side effects";
2489   // Clear all the copies which have been added, and all the related
2490   // instructions (like GTE and tuples).
2491   int64_t removed_copies = 0;
2492   for (HloComputation* computation : module->computations()) {
2493     for (HloInstruction* instruction :
2494          computation->MakeInstructionPostOrder()) {
2495       if (instruction->opcode() == HloOpcode::kCopy &&
2496           added_copies_.contains(instruction)) {
2497         VLOG(5) << "Removing added copy: " << instruction->ToString();
2498         TF_RETURN_IF_ERROR(
2499             instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2500         TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2501         ++removed_copies;
2502       }
2503     }
2504   }
2505   added_copies_.clear();
2506   unconstrained_layout_instructions_.clear();
2507   if (removed_copies > 0) {
2508     TupleSimplifier tuple_simplifier;
2509     HloDCE dce;
2510     TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2511     TF_RETURN_IF_ERROR(dce.Run(module).status());
2512     call_graph_ = CallGraph::Build(module);
2513   }
2514   return Status::OK();
2515 }
2516 
AddCopyForOperand(HloInstruction * instruction,int64_t operand_number)2517 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2518                                            int64_t operand_number) {
2519   HloInstruction* operand = instruction->mutable_operand(operand_number);
2520   if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2521     HloInstruction* copy =
2522         instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2523             operand->shape(), HloOpcode::kCopy, operand));
2524     SetupCopiedInstruction(*operand, copy, {});
2525     LayoutUtil::ClearLayout(copy->mutable_shape());
2526     TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2527   }
2528   return Status::OK();
2529 }
2530 
2531 }  // namespace xla
2532