• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28 
29 #include "absl/algorithm/container.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/map_util.h"
37 #include "tensorflow/compiler/xla/service/call_graph.h"
38 #include "tensorflow/compiler/xla/service/computation_layout.h"
39 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
40 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
41 #include "tensorflow/compiler/xla/service/hlo_computation.h"
42 #include "tensorflow/compiler/xla/service/hlo_dce.h"
43 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
44 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
45 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
46 #include "tensorflow/compiler/xla/service/logical_buffer.h"
47 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
48 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
49 #include "tensorflow/compiler/xla/shape_layout.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/status_macros.h"
52 #include "tensorflow/compiler/xla/statusor.h"
53 #include "tensorflow/compiler/xla/types.h"
54 #include "tensorflow/compiler/xla/util.h"
55 #include "tensorflow/compiler/xla/xla_data.pb.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/lib/core/status.h"
58 #include "tensorflow/core/platform/logging.h"
59 #include "tensorflow/core/platform/protobuf.h"
60 
61 namespace xla {
62 
operator <<(std::ostream & out,const LayoutConstraint & constraint)63 std::ostream& operator<<(std::ostream& out,
64                          const LayoutConstraint& constraint) {
65   out << constraint.ToString();
66   return out;
67 }
68 
// Constructs a constraint pinning `buffer` to `layout`. CHECK-fails if
// `layout` is not a valid layout for the buffer's shape. Stores a pointer to
// `buffer`, so the buffer must outlive this constraint.
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
75 
ToString() const76 string BufferLayoutConstraint::ToString() const {
77   return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
78                          LayoutUtil::HumanString(layout_));
79 }
80 
// Constrains operand `operand_no` of `instruction` to `shape_layout`.
// CHECK-fails if the shape layout is unset, or if its shape is not
// compatible (ignoring layout) with the operand's shape.
OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64 operand_no, bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}
95 
ToString() const96 string OperandLayoutConstraint::ToString() const {
97   return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
98                          instruction_->name(), operand_no_,
99                          shape_layout_.ToString());
100 }
101 
ToString() const102 string ResultLayoutConstraint::ToString() const {
103   return absl::StrFormat("ResultLayoutConstraint: %s",
104                          shape_layout_.ToString());
105 }
106 
// Initializes the constraint set for `computation`: every array-shaped
// logical buffer defined in this computation starts out unconstrained.
// SetBufferLayout later removes ids from unconstrained_buffer_ids_ as
// constraints are added.
LayoutConstraints::LayoutConstraints(
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation)
    : points_to_analysis_(points_to_analysis), computation_(computation) {
  // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
  for (HloInstruction* inst : computation_->instructions()) {
    points_to_analysis_.GetPointsToSet(inst).ForEachElement(
        [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
          for (const LogicalBuffer* buffer : buffers) {
            // The points to analysis is computed per module, restrict
            // constraints to array buffers in this computation.
            if (buffer->IsArray() &&
                buffer->instruction()->parent() == computation) {
              unconstrained_buffer_ids_.insert(buffer->id());
            }
          }
        });
  }
}
126 
GetBufferSet(const HloInstruction * instruction) const127 PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
128     const HloInstruction* instruction) const {
129   auto it = buffer_sets_cache_.find(instruction);
130   if (it != buffer_sets_cache_.end()) {
131     return it->second.get();
132   }
133   auto& buffer_set =
134       buffer_sets_cache_
135           .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
136           .first->second;
137   const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
138   points_to_set.ForEachElement(
139       [&buffer_set](const ShapeIndex& /*index*/,
140                     const PointsToSet::BufferList& buffers) {
141         buffer_set->insert(buffers.begin(), buffers.end());
142       });
143   return buffer_set.get();
144 }
145 
OperandBufferForwarded(const HloInstruction * instruction,int64 operand_no) const146 bool LayoutConstraints::OperandBufferForwarded(
147     const HloInstruction* instruction, int64 operand_no) const {
148   // The operand is potentially forwarded if the intersection of points-to sets
149   // of the operand and the instruction is non-empty.
150   PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
151   PointsToSet::BufferSet* operand_buffers =
152       GetBufferSet(instruction->operand(operand_no));
153   return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
154     return operand_buffers->count(b) > 0;
155   });
156 }
157 
SetBufferLayout(const Layout & layout,const LogicalBuffer & buffer,bool mandatory,bool dfs)158 Status LayoutConstraints::SetBufferLayout(const Layout& layout,
159                                           const LogicalBuffer& buffer,
160                                           bool mandatory, bool dfs) {
161   VLOG(3) << "SetBufferLayout : " << buffer << " : "
162           << LayoutUtil::HumanString(layout);
163 
164   TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
165   if (!buffer.IsArray()) {
166     return FailedPrecondition(
167         "Layout of buffer %s cannot be constrained because buffer is not "
168         "array-shaped, has shape: %s",
169         buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
170   }
171   TF_RETURN_IF_ERROR(
172       LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
173 
174   auto iter = buffer_constraints_.find(&buffer);
175   if (iter != buffer_constraints_.end()) {
176     const BufferLayoutConstraint& curr_constraint = iter->second;
177     if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
178       // New constraint matches existing constraint. Nothing to do.
179       return Status::OK();
180     }
181     if (curr_constraint.mandatory()) {
182       if (!mandatory) {
183         VLOG(3) << "Buffer" << buffer
184                 << " already has a mandatory layout constrain, skipping";
185         return Status::OK();
186       }
187       return FailedPrecondition(
188           "Buffer %s already has the layout constraint %s, cannot add "
189           "incompatible constraint %s",
190           buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
191           LayoutUtil::HumanString(layout));
192     }
193     iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
194   } else {
195     TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
196         << buffer.ToString();
197     iter = buffer_constraints_
198                .insert(std::make_pair(
199                    &buffer,
200                    BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
201                .first;
202   }
203   added_constraints_.push_back(&iter->second);
204   return Status::OK();
205 }
206 
SetOperandLayout(const Shape & shape_with_layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)207 Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
208                                            const HloInstruction* instruction,
209                                            int64 operand_no, bool mandatory,
210                                            bool dfs) {
211   VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
212           << operand_no << " : "
213           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
214 
215   const OperandLayoutConstraint* curr_shape_layout =
216       GetOperandLayoutConstraint(instruction, operand_no);
217   if (curr_shape_layout != nullptr) {
218     if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
219             shape_with_layout, /*minor_to_major_only=*/true)) {
220       // New constraint matches existing constraint. Nothing to do.
221       return Status::OK();
222     }
223     if (curr_shape_layout->mandatory()) {
224       return FailedPrecondition(
225           "Operand %d of instruction %s already has a layout constraint "
226           "%s, cannot add incompatible constraint %s",
227           operand_no, instruction->name(),
228           curr_shape_layout->shape_layout().ToString(),
229           ShapeUtil::HumanStringWithLayout(shape_with_layout));
230     }
231   }
232 
233   // If any buffers in the operand occur in the output of the instruction, then
234   // return an error. This case is not handled because such a constraint changes
235   // layouts beyond this immediate use and is complicated to handle.
236   if (OperandBufferForwarded(instruction, operand_no)) {
237     return FailedPrecondition(
238         "Cannot constraint layout of operand %d of instruction %s "
239         "because instruction forwards operand's LogicalBuffer(s)",
240         operand_no, instruction->name());
241   }
242 
243   auto key = std::make_pair(instruction, operand_no);
244   auto iter = operand_constraints_.find(key);
245   if (iter == operand_constraints_.end()) {
246     auto pair = std::make_pair(
247         key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
248                                      instruction, operand_no, mandatory, dfs));
249     iter = operand_constraints_.insert(pair).first;
250   } else {
251     iter->second =
252         OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
253                                 operand_no, mandatory, dfs);
254   }
255   added_constraints_.push_back(&iter->second);
256 
257   return Status::OK();
258 }
259 
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)260 Status LayoutConstraints::SetArrayOperandLayout(
261     const Layout& layout, const HloInstruction* instruction, int64 operand_no,
262     bool mandatory, bool dfs) {
263   const HloInstruction* operand = instruction->operand(operand_no);
264   TF_RET_CHECK(operand->shape().IsArray());
265   Shape shape(operand->shape());
266   *shape.mutable_layout() = layout;
267   TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
268   return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
269 }
270 
SetResultLayout(const Shape & shape_with_layout,bool dfs)271 Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
272                                           bool dfs) {
273   VLOG(3) << "SetResultLayout : "
274           << ShapeUtil::HumanStringWithLayout(shape_with_layout);
275 
276   const ShapeLayout* curr_shape_layout = ResultLayout();
277   if (curr_shape_layout != nullptr) {
278     if (!curr_shape_layout->MatchesLayoutInShape(
279             shape_with_layout, /*minor_to_major_only=*/true)) {
280       return FailedPrecondition(
281           "Result of computation %s already has the layout constraint %s, "
282           "cannot add incompatible constraint %s",
283           computation_->name(), curr_shape_layout->ToString(),
284           ShapeUtil::HumanStringWithLayout(shape_with_layout));
285     }
286     // New constraint matches existing constraint. Nothing to do.
287     return Status::OK();
288   }
289   result_constraint_.reset(
290       new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
291   added_constraints_.push_back(result_constraint_.get());
292 
293   return Status::OK();
294 }
295 
// Constrains every array-shaped buffer defined by `instruction` to the
// corresponding subshape layout in `shape_with_layout`. The shape must be
// compatible (ignoring layout) with the instruction's shape, and the
// instruction must define every buffer in its output (CHECK-enforced below).
// NOTE(review): `dfs` is not forwarded to SetBufferLayout here, so the
// SetBufferLayout default is used — confirm this is intended.
Status LayoutConstraints::SetInstructionLayout(
    const Shape& shape_with_layout, const HloInstruction* instruction,
    bool mandatory, bool dfs) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanStringWithLayout(shape_with_layout));
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory](const Shape& subshape,
                                     const ShapeIndex& index) -> Status {
        // The precondition for this method is that the instruction defines all
        // buffers in its output.
        auto buffers =
            points_to_analysis_.GetPointsToSet(instruction).element(index);
        CHECK_EQ(1, buffers.size());
        CHECK_EQ(buffers[0]->instruction(), instruction);

        // Tuple (non-array) subshapes carry no layout; skip them.
        if (subshape.IsArray() && subshape.has_layout()) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
        } else {
          return Status::OK();
        }
      });
}
329 
BufferLayout(const LogicalBuffer & buffer) const330 const Layout* LayoutConstraints::BufferLayout(
331     const LogicalBuffer& buffer) const {
332   if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
333     return &constraint->layout();
334   }
335   return nullptr;
336 }
337 
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const338 const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
339     const LogicalBuffer& buffer) const {
340   auto it = buffer_constraints_.find(&buffer);
341   return it == buffer_constraints_.end() ? nullptr : &it->second;
342 }
343 
OperandLayout(const HloInstruction * instruction,int64 operand_no) const344 const ShapeLayout* LayoutConstraints::OperandLayout(
345     const HloInstruction* instruction, int64 operand_no) const {
346   if (const auto* constraint =
347           GetOperandLayoutConstraint(instruction, operand_no)) {
348     return &constraint->shape_layout();
349   }
350   return nullptr;
351 }
352 
GetOperandLayoutConstraint(const HloInstruction * instruction,int64 operand_no) const353 const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
354     const HloInstruction* instruction, int64 operand_no) const {
355   auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
356   return it == operand_constraints_.end() ? nullptr : &it->second;
357 }
358 
ResultLayout() const359 const ShapeLayout* LayoutConstraints::ResultLayout() const {
360   return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
361 }
362 
ToString() const363 string LayoutConstraints::ToString() const {
364   string output;
365   absl::StrAppend(&output, "LayoutConstraints for computation ",
366                   computation_->name(), ":\n");
367   for (auto* instruction : computation_->MakeInstructionPostOrder()) {
368     absl::StrAppend(&output, "  ", instruction->ToShortString(), "\n");
369     for (int64 i = 0; i < instruction->operand_count(); ++i) {
370       if (OperandLayout(instruction, i) != nullptr) {
371         absl::StrAppend(&output, "    operand (", i,
372                         "): ", OperandLayout(instruction, i)->ToString(), "\n");
373       }
374     }
375     for (const LogicalBuffer* buffer :
376          points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
377       if (BufferLayout(*buffer) != nullptr) {
378         absl::StrAppend(&output, "    ", buffer->ToString(), " : ",
379                         LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
380       }
381     }
382   }
383 
384   if (ResultLayout() != nullptr) {
385     absl::StrAppend(&output, "  => ", ResultLayout()->ToString(), "\n");
386   }
387   return output;
388 }
389 
390 namespace {
391 
IsHostSendRecv(const HloInstruction * instruction)392 bool IsHostSendRecv(const HloInstruction* instruction) {
393   const HloSendRecvInstruction* send_recv_instr =
394       DynCast<HloSendRecvInstruction>(instruction);
395   return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
396 }
397 
398 }  // namespace
399 
// Seeds host_channel_constraints_ from the host-transfer Send/Recv
// instructions in `computation`: for each such Send/Recv, the layout of its
// data element (tuple element 0) is registered for the channel. Returns an
// error if a channel would be constrained twice (ConstrainChannel returns the
// previously registered layout in that case).
Status LayoutAssignment::BuildHostChannelConstraints(
    HloComputation* computation) {
  for (auto* instruction : computation->instructions()) {
    const HloSendRecvInstruction* send_recv_instr =
        DynCast<HloSendRecvInstruction>(instruction);
    // Only host-transfer Send/Recv instructions participate.
    if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
      continue;
    }

    // For host transfers the Send and Recv instruction carry the layout.
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      // The payload is tuple element 0 of the Send/Recv shape and must be an
      // array with a concrete layout.
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
      TF_RET_CHECK(data_shape.IsArray());
      TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
      const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
          *send_recv_instr->channel_id(), data_shape.layout());
      TF_RET_CHECK(prev_layout == nullptr)
          << "Cannot constrain host transfer layout as it was set to "
          << LayoutUtil::HumanString(*prev_layout) << ": "
          << send_recv_instr->ToString();
    }
  }
  return Status::OK();
}
426 
427 namespace {
428 
IsLayoutConstrainedCustomCall(HloInstruction * instruction)429 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
430   const HloCustomCallInstruction* custom_call =
431       DynCast<HloCustomCallInstruction>(instruction);
432   return custom_call != nullptr && custom_call->layout_constrained();
433 }
434 
IsLayoutConstrainedAllReduce(const HloInstruction * instruction)435 bool IsLayoutConstrainedAllReduce(const HloInstruction* instruction) {
436   const HloAllReduceInstruction* all_reduce =
437       DynCast<HloAllReduceInstruction>(instruction);
438   return all_reduce != nullptr && all_reduce->constrain_layout();
439 }
440 
441 }  // namespace
442 
// Adds the non-negotiable layout constraints for `computation` to
// `constraints`:
//  - instructions with pre-existing layouts (infeed/outfeed, parameters with
//    a ComputationLayout, layout-constrained custom calls and all-reduces,
//    and Send/Recv or cross-module all-reduce on already-constrained
//    channels), and
//  - callers (kCall, kWhile, kConditional) whose callee computations already
//    have assigned layouts, which must agree with those layouts.
// May also mutate the callee ComputationLayouts for while loops (forcing
// body parameter == body result == condition parameter) and record
// conditional branches whose result layout mismatches the chosen branch in
// conditional_mismatch_. Finally constrains the computation result layout if
// one is known.
Status LayoutAssignment::AddMandatoryConstraints(
    const ComputationLayout* computation_layout,
    ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
    LayoutConstraints* constraints) {
  VLOG(3) << "Adding mandatory layout constraints to computation "
          << computation->name();

  // Host-transfer channels use the pass-owned constraint table; everything
  // else uses the caller-provided (possibly null) cross-module table.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };

  // Constrain layouts of instructions which define values with pre-existing
  // layouts.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kInfeed) {
      // Infeed layouts must match the layout of the original inserted
      // instruction.
      // TODO(b/31425034): Change infeeds to be more like parameters, with
      // shapes in the ComputationLayout.
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kOutfeed) {
      // Constrain the input to the Outfeed instruction to be the expected
      // layout of the Outfeed.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          instruction->outfeed_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kParameter) {
      if (computation_layout != nullptr) {
        const ShapeLayout& parameter_layout =
            computation_layout->parameter_layout(
                instruction->parameter_number());
        // Parameter layouts must match the respective layout in
        // ComputationLayout, if there is one.
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            parameter_layout.shape(), instruction));
      }
    } else if (IsLayoutConstrainedCustomCall(instruction)) {
      // Client-constrained custom call: pin both result and operand layouts.
      const HloCustomCallInstruction* custom_call =
          DynCast<HloCustomCallInstruction>(instruction);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(custom_call->shape(), custom_call));
      for (int64 i = 0; i < custom_call->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            custom_call->operand_shapes_with_layout()[i], custom_call, i));
      }
    } else if (instruction->opcode() == HloOpcode::kSend ||
               instruction->opcode() == HloOpcode::kRecv) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = *instruction->channel_id();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      if (instruction->opcode() == HloOpcode::kSend) {
        // TODO(b/68493863): Change to use SetOperandLayout().
        const Shape send_buffer_shape = instruction->operand(0)->shape();
        TF_RET_CHECK(send_buffer_shape.IsArray());
        Shape new_buffer_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(send_buffer_shape,
                                        *instruction->channel_id());
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            new_buffer_shape, instruction->operand(0)));
      } else {
        // Recv: constrain the buffer of the data element (tuple index {0}).
        const Shape recv_buffer_shape =
            ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
        TF_RET_CHECK(recv_buffer_shape.IsArray());
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                 {0}));
        Shape new_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(recv_buffer_shape,
                                        *instruction->channel_id());
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    } else if (IsLayoutConstrainedAllReduce(instruction)) {
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->IsCrossModuleAllReduce()) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = instruction->channel_id().value();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      // TODO(b/68493863): Change to use SetOperandLayout().
      const Shape& buffer_shape = instruction->operand(0)->shape();
      TF_RET_CHECK(buffer_shape.IsArray());
      Shape new_buffer_shape =
          get_channel_constraints(instruction)
              ->LayoutShapeForChannel(buffer_shape, channel_id);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(new_buffer_shape, instruction));
    }
  }

  // Constrain layouts of instructions which call computations which have
  // already been assigned layouts. Instructions which call computations in a
  // parallel element-wise context (eg, map or reduce) do not need layout
  // constraints because they operate on scalars.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kCall) {
      // kCall instruction operands and output must match the ComputationLayout
      // of the called computation.
      const ComputationLayout& called_computation_layout =
          FindOrDie(computation_layouts_, instruction->to_apply());
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          called_computation_layout.result_layout().shape(), instruction));
      TF_RET_CHECK(instruction->operand_count() ==
                   called_computation_layout.parameter_count());
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            called_computation_layout.parameter_layout(i).shape(), instruction,
            i));
      }
    } else if (instruction->opcode() == HloOpcode::kWhile) {
      // Layout of input and output of kWhile instruction must be equal and must
      // match both input and output of body computation. Also, the input of
      // condition computation must match kWhile layout.
      HloComputation* body = instruction->while_body();
      HloComputation* condition = instruction->while_condition();
      const HloInstruction* init = instruction->operand(0);
      ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
      ComputationLayout& condition_layout =
          FindOrDie(computation_layouts_, condition);

      // Check a few invariants irrespective of layout.
      CHECK_EQ(1, instruction->operand_count());
      CHECK_EQ(1, body->num_parameters());
      CHECK_EQ(1, condition->num_parameters());
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   body_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   condition_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));

      // Force the loop-carried value to have one layout everywhere: the body
      // result layout wins, and both parameter layouts are rewritten to it.
      if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
                << " while=" << instruction->name()
                << " shape=" << body_layout.result_layout().ToString();
        *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
      }
      if (condition_layout.parameter_layout(0) !=
          body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while condition parameter layout: cond="
                << condition->name() << " while=" << instruction->name()
                << " shape=" << body_layout.parameter_layout(0).ToString();
        *condition_layout.mutable_parameter_layout(0) =
            body_layout.parameter_layout(0);
      }

      // Constrain the output and the operand of the while instruction to match
      // the computations.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          body_layout.result_shape(), instruction, 0));
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          body_layout.result_shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kConditional) {
      // Find the conditional branch with the most instructions and force all
      // other computations to match that layout. A potentially better decision
      // could count the number FLOPs or how constrained the layouts are.
      int64 largest_branch = 0;
      int64 largest_instruction_count =
          instruction->branch_computation(0)->instruction_count();
      for (int j = 1; j < instruction->branch_count(); ++j) {
        const int64 instruction_count =
            instruction->branch_computation(j)->instruction_count();
        if (instruction_count > largest_instruction_count) {
          largest_branch = j;
          largest_instruction_count = instruction_count;
        }
      }
      ComputationLayout& best_branch_computation_layout =
          FindOrDie(computation_layouts_,
                    instruction->branch_computation(largest_branch));
      for (int k = 0; k < instruction->branch_count(); ++k) {
        // Visit the best branch first.
        // NOTE(review): `j` (the rotated index that would actually visit the
        // best branch first) is only used in the TF_RET_CHECK below; the rest
        // of the loop body indexes with `k`, so branches are in fact processed
        // in plain order. Confirm whether `k` or `j` is intended here.
        int j = (k + largest_branch) % instruction->branch_count();
        TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
        ComputationLayout& branch_computation_layout =
            FindOrDie(computation_layouts_, instruction->branch_computation(k));
        if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
                best_branch_computation_layout.result_layout().shape(),
                /*minor_to_major_only=*/true)) {
          // Branch result layout disagrees with the chosen branch: drop its
          // layout and remember the layout it must be re-assigned with.
          computation_layouts_.erase(instruction->branch_computation(k));
          InsertOrDie(&conditional_mismatch_,
                      instruction->branch_computation(k),
                      best_branch_computation_layout);
        } else {
          // Operand k+1 of kConditional is the argument to branch k.
          TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
              branch_computation_layout.parameter_shape(0), instruction, k + 1,
              /*mandatory=*/true));
        }
      }
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          best_branch_computation_layout.parameter_shape(0), instruction,
          largest_branch + 1,
          /*mandatory=*/true));
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          best_branch_computation_layout.result_shape(), instruction));
    }
  }
  // Finally set the result layout to match ComputationLayout, if there is one.
  if (conditional_mismatch_.count(computation) > 0) {
    TF_RETURN_IF_ERROR(constraints->SetResultLayout(
        FindOrDie(conditional_mismatch_, computation).result_layout().shape()));
  } else if (computation_layout != nullptr) {
    const ShapeLayout& result_layout = computation_layout->result_layout();
    if (result_layout.LayoutIsSet()) {
      TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
    }
  }
  return Status::OK();
}
663 
664 namespace {
665 
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)666 bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
667   return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
668 }
669 
670 // The operands of a call must match the layouts of parameters in the
671 // ComputationLayout, and the call instruction itself must match the result
672 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)673 Status CheckCallLayout(HloInstruction* call,
674                        const ComputationLayout& computation_layout) {
675   HloComputation* computation = call->to_apply();
676   TF_RET_CHECK(computation->num_parameters() == call->operand_count());
677   for (int64 i = 0; i < computation->num_parameters(); ++i) {
678     TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
679         call->operand(i)->shape(), /*minor_to_major_only=*/true));
680   }
681   TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
682       call->shape(), /*minor_to_major_only=*/true));
683   return Status::OK();
684 }
685 
686 // Operands of layout-constrained custom calls must match the expected
687 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)688 Status CheckCustomCallLayout(HloInstruction* instruction) {
689   if (IsLayoutConstrainedCustomCall(instruction)) {
690     const HloCustomCallInstruction* custom_call =
691         DynCast<HloCustomCallInstruction>(instruction);
692     for (int64 i = 0; i < custom_call->operand_count(); ++i) {
693       TF_RET_CHECK(
694           LayoutsInShapesEqual(custom_call->operand(i)->shape(),
695                                custom_call->operand_shapes_with_layout()[i]));
696     }
697   }
698   return Status::OK();
699 }
700 
701 // For a while instruction, all the following layouts must be the same:
702 //   (1) init operand
703 //   (2) condition computation parameter
704 //   (3) body computation parameter
705 //   (4) body computation result
706 //   (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)707 Status CheckWhileLayout(HloInstruction* while_inst,
708                         const ComputationLayout& condition_computation_layout,
709                         const ComputationLayout& body_computation_layout) {
710   auto init_shape = while_inst->operand(0)->shape();
711   TF_RET_CHECK(
712       condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
713           init_shape, /*minor_to_major_only=*/true));
714   TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
715       init_shape, /*minor_to_major_only=*/true));
716   TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
717       init_shape, /*minor_to_major_only=*/true));
718   TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
719   return Status::OK();
720 }
721 
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)722 Status CheckConditionalLayout(
723     HloInstruction* instruction,
724     absl::Span<const ComputationLayout> branch_computation_layouts) {
725   for (int j = 0; j < instruction->branch_count(); ++j) {
726     const HloInstruction* branch_operand = instruction->operand(j + 1);
727     TF_RET_CHECK(
728         branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
729             branch_computation_layouts[j].result_layout().shape(),
730             /*minor_to_major_only=*/true));
731     TF_RET_CHECK(
732         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
733             instruction->shape(), /*minor_to_major_only=*/true));
734     TF_RET_CHECK(
735         branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
736             instruction->branch_computation(j)->root_instruction()->shape(),
737             /*minor_to_major_only=*/true));
738     TF_RET_CHECK(
739         branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
740             branch_operand->shape(), /*minor_to_major_only=*/true));
741   }
742   return Status::OK();
743 }
744 
745 // Fusion parameters must match the layout of the fusion instructions operands,
746 // and the root of the fusion expression must match the layout of the fusion
747 // instruction.
CheckFusionLayout(HloInstruction * fusion)748 Status CheckFusionLayout(HloInstruction* fusion) {
749   TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
750 
751   TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
752                                     fusion->fused_expression_root()->shape()));
753   for (int64 i = 0; i < fusion->operand_count(); ++i) {
754     TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
755                                       fusion->operand(i)->shape()));
756   }
757   return Status::OK();
758 }
759 
760 // The layout of a parameter must match the respective layout in the
761 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)762 Status CheckParameterLayout(HloInstruction* parameter,
763                             const ComputationLayout& computation_layout) {
764   const ShapeLayout& parameter_layout =
765       computation_layout.parameter_layout(parameter->parameter_number());
766   return ShapeUtil::ForEachSubshapeWithStatus(
767       parameter_layout.shape(),
768       [&](const Shape& subshape, const ShapeIndex& shape_index) {
769         if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
770             !subshape.has_layout()) {
771           return Status::OK();
772         }
773         if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
774                 subshape,
775                 ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
776           return InternalError(
777               "parameter instruction %s does not match layout of computation "
778               "shape: %s",
779               parameter->ToString(), parameter_layout.ToString());
780         }
781         return Status::OK();
782       });
783 }
784 
785 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)786 Status CheckConstantLayout(HloInstruction* constant) {
787   if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
788     return InternalError(
789         "constant instruction %s does not match the layout of its literal %s",
790         constant->ToString(),
791         ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
792   }
793   return Status::OK();
794 }
795 
796 }  // namespace
797 
// Creates and returns a copy of `instruction` whose layout matches
// `shape_with_layout`. Array-shaped instructions get a single kCopy; tuple
// shapes are deep-copied element-wise, recursing only into elements whose
// layouts actually differ. The shapes must be compatible (same structure,
// layout aside).
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      // Extract the element so it can be copied (or passed through) on its
      // own.
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
                                                    instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // The new tuple's layout is taken wholesale from `shape_with_layout`
    // rather than inferred from the elements.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    // Record the copy so later bookkeeping can distinguish copies added by
    // this pass. NOTE(review): only the array path registers its copy; tuple
    // copies above are not registered — presumably intentional, confirm.
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    // Tokens/opaque types cannot be copied.
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
854 
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64 operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
                                                operand->shape())) {
    VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }
  VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // If the operand is only used by a conditional, do the copy inside the branch
  // to avoid overhead for other branches.
  if (instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
      instruction->operand(operand_no)->user_count() == 1) {
    // Operand `operand_no` feeds branch computation `operand_no - 1`
    // (operand 0 is the branch selector).
    auto branch_comp = instruction->branch_computation(operand_no - 1);
    auto param = branch_comp->parameter_instruction(0);
    // Let the parameter take the operand's current layout, then insert the
    // layout-converting copy inside the branch body.
    *param->mutable_shape() = operand->shape();
    // Snapshot the users before inserting the copy, since the copy itself
    // becomes a user of the parameter.
    auto param_users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
                        CreateCopyWithNewLayout(operand_layout.shape(), param));
    for (auto user : param_users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
    }
    VLOG(4) << "New copy of " << operand->ToString() << " is "
            << param_copy->ToString();
    if (param == branch_comp->root_instruction()) {
      branch_comp->set_root_instruction(param_copy,
                                        /*accept_different_shape=*/true);
    }
    // Keep the recorded computation layout in sync with the parameter's new
    // shape.
    *FindOrDie(computation_layouts_, branch_comp).mutable_parameter_layout(0) =
        ShapeLayout(operand->shape());
    return Status::OK();
  }

  // General case: insert a copy (deep copy for tuples) and rewire this use.
  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
906 
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)907 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
908                                               HloInstruction* copy,
909                                               const ShapeIndex& index) {
910   if (instruction.has_sharding()) {
911     // If the index is empty, we want to copy the whole sharding, in case the
912     // sharding is a tuple sharding.
913     HloSharding sharding =
914         !index.empty() && instruction.sharding().IsTuple()
915             ? instruction.sharding().GetSubSharding(instruction.shape(), index)
916             : instruction.sharding();
917     // We propagate the sharding to the copied instruction only if it is a
918     // special sharding, like tiled ones.
919     // Otherwise it is preferable to leave the new instruction without device,
920     // and let the automatic device placer to choose the best location.
921     auto device = sharding.UniqueDevice();
922     if (!device || HloSharding::IsReservedDevice(*device)) {
923       copy->set_sharding(sharding);
924     }
925   }
926   copy->set_metadata(instruction.metadata());
927 }
928 
// Verifies post-assignment invariants over the whole module: every
// non-fusion instruction has a valid layout, every output subshape agrees
// with the logical buffers that may define it, and opcodes with special
// layout contracts (call, custom-call, fusion, parameter, constant, while,
// conditional) satisfy them. Returns an error describing the first
// violation found.
Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              for (const LogicalBuffer* buffer : buffers) {
                if (!Shape::Equal()
                         .IgnoreDynamicDimension()
                         .MinorToMajorOnlyInLayout()(instruction_subshape,
                                                     buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          // Operands/result must agree with the callee's ComputationLayout.
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          // Only layout-constrained custom calls are actually checked.
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          // Parameter layout must match the enclosing computation's layout.
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          // Init, condition/body parameters, body result, and while result
          // must all share one layout.
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional: {
          // Collect every branch's recorded layout before checking them
          // against each other.
          std::vector<ComputationLayout> branch_computation_layouts;
          for (auto branch_computation : instruction->branch_computations()) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation));
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            module->result_shape(), result_layout.shape()));
  }
  return Status::OK();
}
1022 
// Constructs the pass. `entry_computation_layout` is both an input (the
// requested entry layout) and an output (it may be refined by the pass); a
// pristine copy is saved so the pass can restore it when re-running.
// `instruction_can_change_layout_func` tells the pass which instructions may
// produce a layout different from their operands'. `channel_constraints` is
// optional shared cross-module channel state.
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    std::function<bool(const HloInstruction*)>
        instruction_can_change_layout_func,
    ChannelLayoutConstraints* channel_constraints)
    : entry_computation_layout_(entry_computation_layout),

      saved_entry_computation_layout_(*entry_computation_layout),
      channel_layout_constraints_(channel_constraints),
      instruction_can_change_layout_func_(
          std::move(instruction_can_change_layout_func)) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
1042 
// Suggests a layout for operand `operand_no` of `instruction`, given that the
// instruction's output will have `output_layout`. Returns nullptr when there
// is no preference. Both the instruction and the operand must be array-shaped.
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
    const Layout& output_layout, const HloInstruction* instruction,
    int64 operand_no) {
  const HloInstruction* operand = instruction->operand(operand_no);
  CHECK(instruction->shape().IsArray());
  CHECK(operand->shape().IsArray());
  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == instruction->shape().rank() &&
      !instruction_can_change_layout_func_(instruction)) {
    // Propagate the result layout to the operand layout if the instruction
    // requires the same layout for the result and the operand.
    //
    // For elementwise operations, using the same layout for the operands and
    // the result also has the following benefits:
    // 1) the elementwise operation can reuse its operand's buffer, and
    // 2) the input and output elements can reuse the same linear index.
    return absl::make_unique<Layout>(output_layout);
  }

  if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the operand's layout to the output.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(instruction->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }

    const Shape& output_shape = instruction->shape();
    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
        LayoutUtil::MinorToMajor(output_layout));
    // Start the operand from its default layout and ask AlignLayouts for an
    // operand layout that would make the reshape a bitcast, if one exists.
    Shape operand_shape = operand->shape();
    *operand_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(operand_shape);
    auto aligned_operand_shape =
        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
    if (aligned_operand_shape) {
      auto operand_layout = aligned_operand_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
      return absl::make_unique<Layout>(operand_layout);
    }
  }

  if (instruction->opcode() == HloOpcode::kTranspose) {
    // Pick the operand layout that makes the transpose a bitcast.
    int64 rank = instruction->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    for (int64 i = 0; i < rank; ++i) {
      // Map each output dimension (in minor-to-major order) back to the
      // operand dimension it came from.
      int64 output_dim = LayoutUtil::Minor(output_layout, i);
      int64 operand_dim = instruction->dimensions(output_dim);
      new_minor_to_major[i] = operand_dim;
    }
    Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(
        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
    return absl::make_unique<Layout>(operand_layout);
  }

  return nullptr;
}
1109 
// Mirror of ChooseOperandLayoutFromOutputLayout: suggests a layout for
// `user`'s output given that its operand `operand_no` has `operand_layout`.
// Returns nullptr when there is no preference. Both shapes must be arrays.
std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
    const Layout& operand_layout, const HloInstruction* user,
    int64 operand_no) {
  const HloInstruction* operand = user->operand(operand_no);

  CHECK(user->shape().IsArray() && operand->shape().IsArray());

  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == user->shape().rank() &&
      !instruction_can_change_layout_func_(user)) {
    // Assign users the same layout as the operand.
    return absl::make_unique<Layout>(operand_layout);
  }

  if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the output's layout to the operand.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(user->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        operand->shape().element_type(),
        AsInt64Slice(operand->shape().dimensions()),
        LayoutUtil::MinorToMajor(operand_layout));
    // Start the output from its default layout and ask AlignLayouts for an
    // output layout that would make the reshape a bitcast, if one exists.
    Shape output_shape = user->shape();
    *output_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(output_shape);
    auto aligned_user_shape =
        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
    if (aligned_user_shape) {
      auto user_layout = aligned_user_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
      return absl::make_unique<Layout>(user_layout);
    }
  }

  if (user->opcode() == HloOpcode::kTranspose) {
    // Pick the user layout that makes the transpose a bitcast.
    int64 rank = user->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    // Invert the transpose permutation to map operand dimensions to output
    // dimensions.
    auto inverse_dimensions = InversePermutation(user->dimensions());
    for (int64 i = 0; i < rank; ++i) {
      int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
      int64 user_dim = inverse_dimensions[operand_dim];
      new_minor_to_major[i] = user_dim;
    }
    Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
    return absl::make_unique<Layout>(user_layout);
  }

  return nullptr;
}
1170 
PropagateConstraints(LayoutConstraints * constraints)1171 Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
1172   // Gathers all initial constraints in a worklist and propagates them in
1173   // depth-first order. DFS order seems to be better than BFS because a
1174   // constraint is propagated as far as possible before propagating unrelated
1175   // constraints which makes it less likely that conflicting constraints will be
1176   // propagated to instructions. However, we should experiment with other orders
1177   // too.
1178   std::deque<const LayoutConstraint*> worklist;
1179 
1180   // Lambda for moving newly added constraints to the worklist.
1181   auto add_new_constraints_to_worklist = [constraints, &worklist]() {
1182     // Add constraints to the front of the deque for DFS ordering.
1183     for (auto* constraint : constraints->ConsumeAddedConstraints()) {
1184       if (constraint->dfs()) {
1185         worklist.push_front(constraint);
1186       } else {
1187         worklist.push_back(constraint);
1188       }
1189     }
1190   };
1191   add_new_constraints_to_worklist();
1192 
1193   while (!worklist.empty()) {
1194     const LayoutConstraint* layout_constraint = worklist.front();
1195     worklist.pop_front();
1196     VLOG(2) << "Propagating " << layout_constraint->ToString()
1197             << " to its neighbors.";
1198     if (auto* buffer_constraint =
1199             dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
1200       TF_RETURN_IF_ERROR(
1201           PropagateBufferConstraint(*buffer_constraint, constraints));
1202     } else if (auto* operand_constraint =
1203                    dynamic_cast<const OperandLayoutConstraint*>(
1204                        layout_constraint)) {
1205       TF_RETURN_IF_ERROR(
1206           PropagateOperandConstraint(*operand_constraint, constraints));
1207     } else if (auto* result_constraint =
1208                    dynamic_cast<const ResultLayoutConstraint*>(
1209                        layout_constraint)) {
1210       TF_RETURN_IF_ERROR(
1211           PropagateResultConstraint(*result_constraint, constraints));
1212     } else {
1213       LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
1214     }
1215 
1216     add_new_constraints_to_worklist();
1217   }
1218   return Status::OK();
1219 }
1220 
1221 namespace {
1222 
1223 // Returns a vector containing all array-shaped uses (instruction and operand
1224 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const LogicalBuffer & buffer,const TuplePointsToAnalysis & points_to_analysis)1225 std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
1226     const LogicalBuffer& buffer,
1227     const TuplePointsToAnalysis& points_to_analysis) {
1228   CHECK(buffer.IsArray());
1229   std::vector<std::pair<const HloInstruction*, int64>> uses;
1230   for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
1231     if (!buffer_alias.instruction()->shape().IsArray()) {
1232       continue;
1233     }
1234     // This alias must be the top-level (index == {}) of the instruction's
1235     // result because the instruction produces an array.
1236     CHECK(buffer_alias.index().empty());
1237 
1238     // Add all uses of the instruction's output.
1239     for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1240       for (int64 operand_no :
1241            user->OperandIndices(buffer_alias.instruction())) {
1242         uses.emplace_back(user, operand_no);
1243       }
1244     }
1245   }
1246   return uses;
1247 }
1248 
1249 }  // namespace
1250 
// Propagates a layout constraint imposed at a use site (`instruction`)
// backwards to the logical buffers that may flow into it. Only array-shaped
// buffers without an existing layout constraint are touched; buffers that are
// already constrained keep their layout.
Status LayoutAssignment::PropagateUseConstraintToDefs(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    LayoutConstraints* constraints) {
  // Try to set all logical buffers which may be sources of the given operand to
  // match the given layout.
  const PointsToSet& points_to_set =
      constraints->points_to_analysis().GetPointsToSet(instruction);
  return points_to_set.ForEachElementWithStatus(
      [&shape_layout, constraints](
          const ShapeIndex& index,
          const PointsToSet::BufferList& buffers) -> Status {
        // Layouts live only on leaf (array) subshapes of the constrained
        // shape; interior tuple nodes carry no layout.
        if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
          for (const LogicalBuffer* buffer : buffers) {
            // Skip buffers that already have a constraint or are not arrays.
            if (constraints->BufferLayout(*buffer) == nullptr &&
                buffer->shape().IsArray()) {
              TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                  ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
                  *buffer, /*mandatory=*/true));
            }
          }
        }
        return Status::OK();
      });
}
1275 
1276 namespace {
1277 // A transpose or a reshape that only changes trivial dimensions have meaningful
1278 // layouts that are valuable to propagate in a depthfirst manner to avoid
1279 // unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo,bool forward_propagation=true)1280 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
1281                                           bool forward_propagation = true) {
1282   switch (hlo.opcode()) {
1283     case HloOpcode::kFusion:
1284       return hlo.IsCustomFusion();
1285     case HloOpcode::kGather:
1286       return true;
1287     case HloOpcode::kReshape:
1288       return hlo.operand(0)->shape().rank() == 1 ||
1289              (forward_propagation &&
1290               std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
1291     case HloOpcode::kScatter:
1292     case HloOpcode::kTranspose:
1293       return true;
1294     default:
1295       return false;
1296   }
1297 }
1298 
1299 }  // namespace
1300 
// Propagates an operand layout constraint in three directions: backwards to
// the buffers that define the operand, sideways to sibling operands of the
// same user (when the user cannot change layouts), and forwards to the user's
// output buffers.
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    return Status::OK();
  }

  if (user->opcode() == HloOpcode::kAllReduce) {
    // A single-operand all-reduce produces an array result; with multiple
    // operands, operand i corresponds to tuple element {i} of the output.
    const auto shape_index =
        user->operand_count() == 1
            ? ShapeIndex()
            : ShapeIndex({operand_constraint.operand_no()});
    TF_ASSIGN_OR_RETURN(const LogicalBuffer* buffer,
                        constraints->points_to_analysis().GetBufferDefinedAt(
                            user, shape_index));
    const BufferLayoutConstraint* constraint =
        constraints->GetBufferLayoutConstraint(*buffer);
    if (constraint == nullptr) {
      // Mirror the operand's layout onto the corresponding output buffer as a
      // non-mandatory preference.
      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          operand_constraint.shape_layout().layout(), *buffer,
          /*mandatory=*/false));
    }
  }
  if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
    return Status::OK();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->OperandBufferForwarded(user,
                                          operand_constraint.operand_no())) {
    return Status::OK();
  }

  // Rank <= 1 operands have (at most) a single possible layout, so there is
  // nothing useful to propagate.
  int64 operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    return Status::OK();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!instruction_can_change_layout_func_(user)) {
    // Make sure all siblings have the same layout as the operand.
    for (int64 operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      const int64 sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        //   operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        //   we continue to propagate our layout to the result, and then
        //   backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false));
    }
    // Also propagate the operand layout forward to the user's output buffers
    // of matching rank (non-mandatory).
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return Status::OK();
          }
          if (subshape.rank() <= 1) {
            return Status::OK();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return Status::OK();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(
                  user, shape_index));
          // Make sure the output has the same layout as the operand.
          const BufferLayoutConstraint* constraint =
              constraints->GetBufferLayoutConstraint(*buffer);
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          if (constraint == nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                operand_constraint.shape_layout().layout(), *buffer,
                /*mandatory=*/false));
          }
          return Status::OK();
        }));
    return Status::OK();
  }
  // The user may change layouts: ask the backend-specific hook to choose an
  // output layout matching the constrained operand layout, if any.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return Status::OK();
        }
        if (subshape.rank() <= 1) {
          return Status::OK();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(user,
                                                                 shape_index));
        // Only override an existing buffer constraint if it is non-mandatory.
        if (constraints->BufferLayout(*buffer) == nullptr ||
            !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}
1459 
// Propagates a buffer layout constraint backwards from the defining
// instruction's output into its operands. For non-layout-changing
// instructions the layout is copied verbatim (mandatory); otherwise the
// backend-specific hook is asked to choose a preferred operand layout.
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  // Rank <= 1 shapes have at most one possible layout; nothing to propagate.
  if (IsAtMostRank1(instruction->shape())) {
    return Status::OK();
  }

  if (instruction->opcode() == HloOpcode::kAllReduce) {
    // For all-reduce, the output buffer at tuple index {i} corresponds to
    // operand i (or operand 0 for a single-operand all-reduce); the operand
    // must get the same layout, mandatorily.
    TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
        buffer_constraint.layout(), instruction,
        instruction->operand_count() == 1 ? 0 : buffer.index()[0],
        /*mandatory=*/true));
    return Status::OK();
  }
  for (int64 operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!instruction_can_change_layout_func_(instruction)) {
      // Copy the layout to the operand.
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true));
      }
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no constraint already.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        // Ask the backend-specific hook for a preferred operand layout given
        // the output layout; nullptr means "no preference".
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/
              InstructionShouldPropagateDepthFirst(
                  *instruction, /*forward_propagation=*/false)));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return Status::OK();
}
1523 
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1524 Status LayoutAssignment::PropagateBufferConstraint(
1525     const BufferLayoutConstraint& buffer_constraint,
1526     LayoutConstraints* constraints) {
1527   // Only propagate array layouts.
1528   const LogicalBuffer& buffer = buffer_constraint.buffer();
1529   if (!buffer.IsArray()) {
1530     return Status::OK();
1531   }
1532   TF_RETURN_IF_ERROR(
1533       PropagateBufferConstraintToUses(buffer_constraint, constraints));
1534   return PropagateBufferConstraintToOperands(buffer_constraint, constraints);
1535 }
1536 
// Propagates a buffer layout constraint forward to every array-shaped use of
// the buffer, and additionally through the backedge of an enclosing kWhile
// loop (from the body's root tuple back to the body's parameter).
Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
    const HloInstruction* user = user_operand_no.first;
    int64 operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !constraints->OperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
    }
  }

  // Propagate to backedges of kWhile. Only applies when the defining
  // computation has exactly one caller and that caller is a while op.
  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
  if (node.caller_callsites().size() != 1) {
    return Status::OK();
  }
  const HloInstruction* parent = node.caller_callsites()[0].instruction();
  if (parent->opcode() != HloOpcode::kWhile) {
    return Status::OK();
  }

  for (HloInstruction* user : buffer.instruction()->users()) {
    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
      continue;
    }
    // The buffer feeds the while body's root tuple: propagate its layout to
    // the corresponding element of the body's parameter (the backedge).
    if (user->parent()->root_instruction() == user) {
      VLOG(3) << "Propagating layout through backedge"
              << buffer_constraint.layout().ToString();
      int64 index = user->operand_index(buffer.instruction());
      TF_ASSIGN_OR_RETURN(
          auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
                           user->parent()->parameter_instruction(0), {index}));

      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          buffer_constraint.layout(), *buffer, /*mandatory=*/false));
    }
  }

  return Status::OK();
}
1587 
PropagateResultConstraint(const ResultLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1588 Status LayoutAssignment::PropagateResultConstraint(
1589     const ResultLayoutConstraint& layout_constraint,
1590     LayoutConstraints* constraints) {
1591   // Propagate the use constraint of the root instruction up to the logical
1592   // buffers which make up the result.
1593   return PropagateUseConstraintToDefs(
1594       layout_constraint.shape_layout(),
1595       constraints->computation()->root_instruction(), constraints);
1596 }
1597 
1598 namespace {
1599 
// Infers the layout of the array at the given index in the given instruction's
// output using points-to analysis. Precondition: The given instruction must
// not produce this array value (that is, the array is forwarded from the
// instruction's operands). Fails if the aliased source buffers disagree on
// layout (minor-to-major ordering).
StatusOr<Layout> InferArrayLayout(
    const TuplePointsToAnalysis& points_to_analysis,
    HloInstruction* instruction, const ShapeIndex& index) {
  // This function should only be called for array shapes which don't yet have
  // layouts.
  const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
  TF_RET_CHECK(subshape.IsArray());
  TF_RET_CHECK(!subshape.has_layout());

  // The instruction should not define the buffer at this index.
  TF_RET_CHECK(
      !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
      << instruction->ToString();

  const auto& source_buffers =
      points_to_analysis.GetPointsToSet(instruction).element(index);
  TF_RET_CHECK(!source_buffers.empty());

  // Verify the layout is the same for every LogicalBuffer which this location
  // ('instruction' and 'index') points to.
  const Layout* first_buffer_layout = nullptr;
  for (const LogicalBuffer* source_buffer : source_buffers) {
    if (!source_buffer->shape().has_layout()) {
      // This should not happen because we've assigned layouts to all
      // instructions preceding this one.
      return InternalError("LogicalBuffer %s does not have a layout",
                           source_buffer->ToString());
    }

    if (first_buffer_layout == nullptr) {
      first_buffer_layout = &source_buffer->shape().layout();
    } else if (!Layout::Equal().MinorToMajorOnly()(
                   source_buffer->shape().layout(), *first_buffer_layout)) {
      // The points-to set is ambiguous for this index and the different source
      // buffers have different layouts. This case is possible in valid XLA
      // computations because we do not propagate BufferLayoutConstraints to all
      // LogicalBuffers which may alias the constrained LogicalBuffer at some
      // point in the computation.
      return FailedPrecondition(
          "Array at index {%s} in instruction %s aliases buffers %s "
          "and %s which have different layouts",
          absl::StrJoin(index, ","), instruction->name(),
          source_buffers[0]->ToString(), source_buffer->ToString());
    }
  }

  // All source buffers agree; any of their layouts is the inferred layout.
  return *first_buffer_layout;
}
1652 
// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
// instruction itself. GTEs and constants inside the fusion get derived
// layouts; all other fused instructions have their layouts cleared (except in
// custom fusions, which are left untouched).
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  // Post order guarantees operands are visited before their users, so GTE
  // layout derivation below sees already-assigned operand layouts.
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Fused parameters mirror the layout of the corresponding fusion
      // operand.
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes.
      // But do not clear layouts for other instructions in custom fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return Status::OK();
}
1698 
1699 }  // namespace
1700 
// Materializes the layouts chosen by `constraints` onto the instructions of
// `computation`: sets buffer-constrained layouts, infers aliased layouts via
// points-to analysis, inserts copies where operand constraints disagree with
// the producing instruction's layout, and finally verifies every shape has a
// complete layout.
Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
                                       HloComputation* computation) {
  VLOG(2) << "Assigning layouts to computation: " << computation->name();
  XLA_VLOG_LINES(2, computation->ToString());
  XLA_VLOG_LINES(2, constraints.ToString());

  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Set the layouts of the array shapes this instruction defines as indicated
    // by the respective BufferLayoutConstraints. Any array shapes in the output
    // of the instruction which are not defined by the instruction (eg, array
    // elements in a Tuple instruction) will be assigned below via inference.
    for (const LogicalBuffer* buffer :
         constraints.points_to_analysis().GetBuffersDefinedByInstruction(
             instruction)) {
      if (!buffer->shape().IsArray()) {
        continue;
      }

      TF_RET_CHECK(buffer->instruction() == instruction);
      const Layout* buffer_layout = constraints.BufferLayout(*buffer);
      TF_RET_CHECK(buffer_layout != nullptr);

      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(*buffer_layout, buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = *buffer_layout;
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !subshape->IsArray()) {
            return Status::OK();
          }
          // Set Layout of subshape to match layout of LogicalBuffer which
          // produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(constraints.points_to_analysis(),
                                               instruction, index));
          return Status::OK();
        }));

    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      }
    }

    // Fusion instructions require some layouts to be set on fused instructions
    // inside the fusion instruction.
    if (instruction->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    // Execute extra verification step once the layout has been finalized.
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Shape must be valid.
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }
  return Status::OK();
}
1782 
CalculateComputationLayout(HloComputation * computation)1783 Status LayoutAssignment::CalculateComputationLayout(
1784     HloComputation* computation) {
1785   ComputationLayout computation_layout(computation->ComputeProgramShape(),
1786                                        /*ignore_layouts=*/false);
1787   InsertOrDie(&computation_layouts_, computation, computation_layout);
1788   VLOG(2) << "  Calculated ComputationLayout = "
1789           << computation_layout.ToString();
1790   return Status::OK();
1791 }
1792 
ClearComputationLayouts(HloComputation * computation)1793 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
1794   // Clear existing layouts of the instructions.  All layouts must be assigned
1795   // by the LayoutAssignment pass, except for those on parameters, the
1796   // computation result, and a couple special cases. The former two are
1797   // specified in computation_layout.  Clearing the layouts here avoids hiding
1798   // potential bugs in the layout assignment pass that may accidentally use the
1799   // existing layout.
1800   for (HloInstruction* instruction : computation->instructions()) {
1801     if (instruction->opcode() == HloOpcode::kBitcast) {
1802       // bitcasts are inherently layout sensitive and so a bitcast instruction
1803       // present in the IR before layout assignment is a bug.
1804       return InternalError(
1805           "Unexpected bitcast operation seen during layout assignment: %s.",
1806           instruction->ToString());
1807     }
1808     // Some instructions carry mandatory layouts in their shape.
1809     if (instruction->opcode() != HloOpcode::kInfeed &&
1810         !IsLayoutConstrainedCustomCall(instruction) &&
1811         !IsLayoutConstrainedAllReduce(instruction)) {
1812       LayoutUtil::ClearLayout(instruction->mutable_shape());
1813     }
1814   }
1815   return Status::OK();
1816 }
1817 
// Runs layout assignment on a single computation: clears stale layouts,
// gathers mandatory/backend constraints, propagates them to a fixed point,
// assigns default layouts to any remaining unconstrained buffers, applies the
// result to the HLO, and reconciles the root with the result layout
// constraint (inserting a copy if needed).
Status LayoutAssignment::RunOnComputation(
    ComputationLayout* computation_layout, HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
          << ")";

  // Must be run before clearing layouts.
  TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));

  TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
  if (computation_layout != nullptr) {
    auto it = computation_layouts_.find(computation);
    if (it == computation_layouts_.end()) {
      VLOG(2) << "  New ComputationLayout = " << computation_layout->ToString();
      computation_layouts_.emplace(computation, *computation_layout);
    } else {
      // A previously recorded layout may only be superseded by itself or the
      // entry computation layout.
      TF_RET_CHECK(computation_layout == &it->second ||
                   computation_layout == entry_computation_layout_);
      VLOG(2) << "  Existing ComputationLayout = "
              << computation_layout->ToString();
    }
  } else {
    VLOG(2) << "  No ComputationLayout specified (will be calculated)";
  }

  // Construct LayoutConstraints with all layout constraints of the computation.
  LayoutConstraints constraints(*points_to_analysis_, computation);

  // Add constraints required for correctness on all backends (eg, entry
  // parameter layout constraints).
  TF_RETURN_IF_ERROR(AddMandatoryConstraints(
      computation_layout, channel_constraints, computation, &constraints));

  // Add any backend-specific constraints.
  TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));

  // Propagates layouts from mandatory and backend constraints.
  TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

  // Prior to applying default layouts, we take note of all HLO instructions
  // which lack a layout constraint.
  for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
    unconstrained_layout_instructions_.insert(
        points_to_analysis_->GetBuffer(buffer_id).instruction());
  }

  // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
  // layout and propagate the change.
  while (!constraints.unconstrained_buffer_ids().empty()) {
    int unconstrained_count = constraints.unconstrained_buffer_ids().size();

    // Arbitrarily pick the first unconstrained buffer and give it the default
    // layout (or the literal layout, in case of constants). By construction
    // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
    const LogicalBuffer& buffer = points_to_analysis_->GetBuffer(
        *constraints.unconstrained_buffer_ids().begin());
    const HloInstruction* instruction = buffer.instruction();
    Layout new_layout =
        instruction->opcode() == HloOpcode::kConstant
            ? ShapeUtil::GetSubshape(instruction->literal().shape(),
                                     buffer.index())
                  .layout()
            : LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
    TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
                                                   /*mandatory=*/false));

    TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

    // To verify progress has been made, check that the number of unconstrained
    // buffers has been reduced.
    CHECK_LT(constraints.unconstrained_buffer_ids().size(),
             unconstrained_count);
  }
  // All logical buffers should have constraints at this point. All that
  // remains is assign the constraints to the buffers and infer layouts for
  // aliased buffers.
  TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));

  // If the computation layout wasn't specified, now it is the time to compute
  // it according to the parameters and root instruction layouts.
  // This allows the first pass through this API to record the best flowing
  // layout to parameters and root instruction.
  if (computation_layout == nullptr) {
    TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
  }

  // Record the layouts assigned for any communication ops in
  // channel_constraints so that they are constrained for future modules.
  if (channel_constraints != nullptr) {
    TF_RETURN_IF_ERROR(
        ConstrainChannelLayouts(computation, channel_constraints));
  }

  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape(),
          /*minor_to_major_only=*/true)) {
    // For computations involved in a conditional-layout mismatch, adopt the
    // previously recorded mismatch resolution as the result layout.
    if (conditional_mismatch_.count(computation) > 0) {
      *FindOrDie(computation_layouts_, computation).mutable_result_layout() =
          FindOrDie(conditional_mismatch_, computation).result_layout();
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }
  return Status::OK();
}
1929 
Status LayoutAssignment::ConstrainChannelLayouts(
    HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  // Host send/recv channels are tracked in a separate constraint table
  // (host_channel_constraints_) from device channels.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };
  // First pass: kRecvDone instructions. Their layouts are already assigned at
  // this point, so they must either impose their layout on the channel, or
  // find a matching one already registered (ConstrainChannel() returns
  // nullptr in that case).
  for (HloInstruction* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kRecvDone) {
      // kRecvDone produces a tuple; subshape {0} is the transferred data
      // whose layout constrains the channel.
      const Layout* layout =
          get_channel_constraints(instruction)
              ->ConstrainChannel(
                  *instruction->channel_id(),
                  ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
      // A non-null return means the channel was previously constrained to a
      // conflicting layout, which cannot be reconciled here.
      TF_RET_CHECK(layout == nullptr)
          << instruction->ToString()
          << " cannot constrain layout as it was set to "
          << LayoutUtil::HumanString(*layout);
    }
  }
  // After that we go through the kSend. These are likely going to have a kCopy
  // as operand (otherwise we add it), so in case the constrained layout does
  // not match, we can change the kCopy layout (and the kSend one as well).
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() == HloOpcode::kSend) {
      HloInstruction* operand = instruction->mutable_operand(0);
      get_channel_constraints(instruction)
          ->ConstrainChannel(*instruction->channel_id(),
                             operand->shape().layout());
    } else if (instruction->IsCrossModuleAllReduce()) {
      // Cross-module all-reduce also communicates over a channel; register
      // its result layout so later modules see the same constraint.
      get_channel_constraints(instruction)
          ->ConstrainChannel(instruction->channel_id().value(),
                             instruction->shape().layout());
    }
  }
  return Status::OK();
}
1970 
PropagateMemorySpace(HloModule * module)1971 Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
1972   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
1973   for (auto buffer : alias_analysis->buffers()) {
1974     // First go through values to collect the memory spaces.
1975     int64 buffer_memory_space = Layout::kDefaultMemorySpace;
1976     for (auto value : buffer.values()) {
1977       const Shape& defining_shape = value->defining_position().shape();
1978       int64 memory_space = defining_shape.layout().memory_space();
1979       if (memory_space != Layout::kDefaultMemorySpace) {
1980         if (buffer_memory_space != Layout::kDefaultMemorySpace &&
1981             memory_space != buffer_memory_space) {
1982           return InternalError(
1983               "Buffer %d (%s) has conflicting memory spaces: %d and %d.",
1984               buffer.id(), value->ToShortString(), buffer_memory_space,
1985               memory_space);
1986         }
1987         buffer_memory_space = memory_space;
1988       }
1989     }
1990 
1991     // If we encounter a memory space other than the default, then propagate all
1992     // the positions with the buffer's memory space.
1993     if (buffer_memory_space != Layout::kDefaultMemorySpace) {
1994       for (auto value : buffer.values()) {
1995         for (auto& position : value->positions()) {
1996           Shape* shape = ShapeUtil::GetMutableSubshape(
1997               position.instruction->mutable_shape(), position.index);
1998           shape->mutable_layout()->set_memory_space(buffer_memory_space);
1999         }
2000       }
2001     }
2002   }
2003   return Status::OK();
2004 }
2005 
Status LayoutAssignment::PropagateComputationLayouts(
    HloComputation* computation, ComputationLayout* computation_layout) {
  // Snapshot the layouts the computation actually ended up with;
  // ignore_layouts=false keeps the layouts assigned by this pass.
  ComputationLayout computed_computation_layout(
      computation->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
    ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
    // A parameter needs assignment if any leaf subshape is missing a layout;
    // otherwise every leaf layout must agree with the computed one.
    bool needs_assign = false;
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        param_layout->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          // Only leaves carry layouts; interior tuple nodes are skipped.
          if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
            return Status::OK();
          }
          if (!subshape.has_layout()) {
            needs_assign = true;
            return Status::OK();
          }
          const auto& computed_subshape = ShapeUtil::GetSubshape(
              computed_computation_layout.parameter_shape(i), shape_index);
          if (subshape.layout() != computed_subshape.layout()) {
            return InternalError(
                "Assigned parameter shape %s does not match layout of "
                "computation shape: %s",
                computed_computation_layout.ToString(),
                computation_layout->ToString());
          }
          return Status::OK();
        }));
    if (needs_assign) {
      VLOG(4) << "Assigning layout to parameter " << i << " of computation "
              << computation->name() << ": "
              << computed_computation_layout.parameter_layout(i).ToString();
      *param_layout = computed_computation_layout.parameter_layout(i);
    }
  }
  // Result layout: assign it if unset, otherwise require that it matches the
  // computed layout (comparing minor-to-major only and ignoring dynamic
  // dimensions).
  ShapeLayout* result_layout = computation_layout->mutable_result_layout();
  if (!result_layout->LayoutIsSet()) {
    VLOG(4) << "Assigning result layout of computation " << computation->name()
            << ": " << computed_computation_layout.result_layout().ToString();
    *result_layout = computed_computation_layout.result_layout();
  } else {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            computed_computation_layout.result_layout().shape(),
            result_layout->shape()));
  }
  return Status::OK();
}
2055 
StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
  VLOG(2) << "Running layout assignment on module " << module->name();
  TF_RETURN_IF_ERROR(Init());
  call_graph_ = CallGraph::Build(module);
  // NOTE(review): the computation range is captured before the cloning loop
  // below adds embedded computations to the module — presumably to avoid
  // iterating over computations added mid-loop; confirm before changing.
  auto computations = module->computations();

  // Add copy to the operand of Send instructions, since we cannot call
  // SetOperandLayout on Send instructions as it aliases its input to the
  // output.
  //
  // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
  // operand buffers that aliases with the output.
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kSend) {
        TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
      }
    }
  }

  // Clone Conditional computations with multiple callsites, so each
  // kConditional callsite can carry its own layout for the shared branch
  // computation. All callsites but the last get a fresh clone.
  for (HloComputation* computation : computations) {
    CallGraphNode& node = call_graph_->GetNode(computation);
    if (node.caller_callsites().size() == 1) {
      continue;
    }
    if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
          return caller.instruction()->opcode() == HloOpcode::kConditional;
        })) {
      continue;
    }
    for (int64 i = 0; i < node.caller_callsites().size() - 1; ++i) {
      HloInstruction* caller = node.caller_callsites()[i].instruction();
      if (caller->opcode() == HloOpcode::kConditional) {
        for (int64 k = 0; k < caller->branch_count(); ++k) {
          if (computation == caller->branch_computation(k)) {
            caller->set_branch_computation(
                k, module->AddEmbeddedComputation(computation->Clone()));
            break;
          }
        }
      }
    }
  }

  // Verify computation layout is sane: the entry layout must agree with the
  // entry computation's parameter count and (layout-insensitive) shapes.
  const HloComputation* entry = module->entry_computation();
  TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
               entry->num_parameters());
  for (int64 i = 0; i < entry->num_parameters(); ++i) {
    TF_RET_CHECK(
        ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
                              entry->parameter_instruction(i)->shape()));
  }
  TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
                                     entry->root_instruction()->shape()));

  // We do two passes. The first one we pass a nullptr ComputationLayout to
  // the RunOnComputation() calls (for non entry computations), and we register
  // the ComputationLayout which are naturally flowing in DFS fashion to the
  // parameters and root instruction.
  // Walking in DFS mode though, means that we can end up with incorrect layouts
  // when seen from an outer instruction, which has across-computation
  // constraints to impose.
  // For example, the kWhile instruction needs to enforce the same layouts for
  // the parameters and root of the body, as well as the condition parameters.
  // Similarly, the kConditional instruction needs to enforce the same layouts
  // for the root of the true and false computations.
  // So in the first pass, while allowing the layouts to flow to parameters and
  // root, we also fix up the eventually inconsistent ComputationLayout, which
  // will be then made mandatory by the second pass.
  for (int64 i = 0; i < 2; ++i) {
    VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
    // Each pass starts from a clean slate: copies added by the previous pass
    // are removed and points-to analysis is recomputed.
    TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
    TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                        TuplePointsToAnalysis::Run(module));
    points_to_analysis_ = std::move(points_to_analysis);
    for (auto* computation : module->MakeComputationPostOrder()) {
      // Fusion computations inherit layouts from the fusion instruction.
      if (computation->IsFusionComputation()) {
        continue;
      }
      if (computation == module->entry_computation()) {
        TF_RETURN_IF_ERROR(RunOnComputation(entry_computation_layout_,
                                            module->entry_computation(),
                                            channel_layout_constraints_));
      } else {
        // First pass (and computations flagged with a conditional mismatch)
        // run unconstrained; the second pass enforces the layouts recorded
        // during the first.
        ComputationLayout* computation_layout =
            (i == 0 || conditional_mismatch_.count(computation) > 0)
                ? nullptr
                : &FindOrDie(computation_layouts_, computation);
        TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation,
                                            channel_layout_constraints_));
      }
    }
  }
  TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
                                                 entry_computation_layout_));

  TF_RETURN_IF_ERROR(PropagateMemorySpace(module));

  TF_RETURN_IF_ERROR(CheckLayouts(module));

  // All layouts are reset then reassigned by this pass.
  return true;
}
2162 
/* static */
// Classifies each opcode by whether layout assignment may choose its output
// layout independently of its operands' layouts (true), or must keep operand
// and result layouts coupled (false) — e.g. for elementwise-style ops.
// Deliberately an exhaustive switch with no default case, so adding a new
// HloOpcode forces a compile-time decision here.
bool LayoutAssignment::InstructionCanChangeLayout(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kWhile:
    case HloOpcode::kSetDimensionSize:
    // AllReduce is variadic so it needs to be careful to assign the same layout
    // to the corresponding input argument and Tuple index.
    case HloOpcode::kAllReduce:
      return false;
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kAfterAll:
    case HloOpcode::kTrace:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
      return true;
  }
}
2274 
2275 /* static */
IsAtMostRank1(const Shape & shape)2276 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2277   if (shape.IsArray()) {
2278     return shape.rank() <= 1;
2279   }
2280   return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2281     return IsAtMostRank1(subshape);
2282   });
2283 }
2284 
Init()2285 Status LayoutAssignment::Init() {
2286   computation_layouts_.clear();
2287   conditional_mismatch_.clear();
2288   *entry_computation_layout_ = saved_entry_computation_layout_;
2289   return Status::OK();
2290 }
2291 
ClearPreviousPassSideEffects(HloModule * module)2292 Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
2293   VLOG(5) << "Clearing previous side effects";
2294   // Clear all the copies which have been added, and all the related
2295   // instructions (like GTE and tuples).
2296   int64 removed_copies = 0;
2297   for (HloComputation* computation : module->computations()) {
2298     for (HloInstruction* instruction :
2299          computation->MakeInstructionPostOrder()) {
2300       if (instruction->opcode() == HloOpcode::kCopy &&
2301           added_copies_.contains(instruction)) {
2302         VLOG(5) << "Removing added copy: " << instruction->ToString();
2303         TF_RETURN_IF_ERROR(
2304             instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2305         TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2306         ++removed_copies;
2307       }
2308     }
2309   }
2310   added_copies_.clear();
2311   unconstrained_layout_instructions_.clear();
2312   if (removed_copies > 0) {
2313     TupleSimplifier tuple_simplifier;
2314     HloDCE dce;
2315     TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2316     TF_RETURN_IF_ERROR(dce.Run(module).status());
2317   }
2318   return Status::OK();
2319 }
2320 
AddCopyForOperand(HloInstruction * instruction,int64 operand_number)2321 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2322                                            int64 operand_number) {
2323   HloInstruction* operand = instruction->mutable_operand(operand_number);
2324   if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2325     HloInstruction* copy =
2326         instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2327             operand->shape(), HloOpcode::kCopy, operand));
2328     SetupCopiedInstruction(*operand, copy, {});
2329     LayoutUtil::ClearLayout(copy->mutable_shape());
2330     TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2331   }
2332   return Status::OK();
2333 }
2334 
2335 }  // namespace xla
2336