1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28
29 #include "absl/algorithm/container.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/map_util.h"
37 #include "tensorflow/compiler/xla/permutation_util.h"
38 #include "tensorflow/compiler/xla/service/call_graph.h"
39 #include "tensorflow/compiler/xla/service/computation_layout.h"
40 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_dce.h"
44 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/logical_buffer.h"
48 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
49 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
50 #include "tensorflow/compiler/xla/shape_layout.h"
51 #include "tensorflow/compiler/xla/shape_util.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/compiler/xla/statusor.h"
54 #include "tensorflow/compiler/xla/types.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/logging.h"
60 #include "tensorflow/core/platform/protobuf.h"
61
62 namespace xla {
63
operator <<(std::ostream & out,const LayoutConstraint & constraint)64 std::ostream& operator<<(std::ostream& out,
65 const LayoutConstraint& constraint) {
66 out << constraint.ToString();
67 return out;
68 }
69
// Pins logical buffer `buffer` to `layout`. The layout must be valid for the
// buffer's shape; an invalid combination is a programming error (CHECK).
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
76
ToString() const77 string BufferLayoutConstraint::ToString() const {
78 return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
79 LayoutUtil::HumanString(layout_));
80 }
81
// Constrains operand `operand_no` of `instruction` to `shape_layout`. The
// layout must be fully set, and the constrained shape must be compatible
// (same structure, ignoring layout) with the operand's shape.
OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64 operand_no, bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}
96
ToString() const97 string OperandLayoutConstraint::ToString() const {
98 return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
99 instruction_->name(), operand_no_,
100 shape_layout_.ToString());
101 }
102
ToString() const103 string ResultLayoutConstraint::ToString() const {
104 return absl::StrFormat("ResultLayoutConstraint: %s",
105 shape_layout_.ToString());
106 }
107
// Builds the constraint set for `computation`. Every array-shaped logical
// buffer defined in this computation starts out unconstrained; constraints
// are attached later through the Set*Layout methods.
LayoutConstraints::LayoutConstraints(
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation)
    : points_to_analysis_(points_to_analysis), computation_(computation) {
  // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
  for (HloInstruction* inst : computation_->instructions()) {
    points_to_analysis_.GetPointsToSet(inst).ForEachElement(
        [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
          for (const LogicalBuffer* buffer : buffers) {
            // The points to analysis is computed per module, restrict
            // constraints to array buffers in this computation.
            if (buffer->IsArray() &&
                buffer->instruction()->parent() == computation) {
              unconstrained_buffer_ids_.insert(buffer->id());
            }
          }
        });
  }
}
127
GetBufferSet(const HloInstruction * instruction) const128 PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
129 const HloInstruction* instruction) const {
130 auto it = buffer_sets_cache_.find(instruction);
131 if (it != buffer_sets_cache_.end()) {
132 return it->second.get();
133 }
134 auto& buffer_set =
135 buffer_sets_cache_
136 .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
137 .first->second;
138 const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
139 points_to_set.ForEachElement(
140 [&buffer_set](const ShapeIndex& /*index*/,
141 const PointsToSet::BufferList& buffers) {
142 buffer_set->insert(buffers.begin(), buffers.end());
143 });
144 return buffer_set.get();
145 }
146
OperandBufferForwarded(const HloInstruction * instruction,int64 operand_no) const147 bool LayoutConstraints::OperandBufferForwarded(
148 const HloInstruction* instruction, int64 operand_no) const {
149 // The operand is potentially forwarded if the intersection of points-to sets
150 // of the operand and the instruction is non-empty.
151 PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
152 PointsToSet::BufferSet* operand_buffers =
153 GetBufferSet(instruction->operand(operand_no));
154 return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
155 return operand_buffers->count(b) > 0;
156 });
157 }
158
SetBufferLayout(const Layout & layout,const LogicalBuffer & buffer,bool mandatory,bool dfs)159 Status LayoutConstraints::SetBufferLayout(const Layout& layout,
160 const LogicalBuffer& buffer,
161 bool mandatory, bool dfs) {
162 VLOG(3) << "SetBufferLayout : " << buffer << " : "
163 << LayoutUtil::HumanString(layout);
164
165 TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
166 if (!buffer.IsArray()) {
167 return FailedPrecondition(
168 "Layout of buffer %s cannot be constrained because buffer is not "
169 "array-shaped, has shape: %s",
170 buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
171 }
172 TF_RETURN_IF_ERROR(
173 LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
174
175 auto iter = buffer_constraints_.find(&buffer);
176 if (iter != buffer_constraints_.end()) {
177 const BufferLayoutConstraint& curr_constraint = iter->second;
178 if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
179 // New constraint matches existing constraint. Nothing to do.
180 return Status::OK();
181 }
182 if (curr_constraint.mandatory()) {
183 if (!mandatory) {
184 VLOG(3) << "Buffer" << buffer
185 << " already has a mandatory layout constrain, skipping";
186 return Status::OK();
187 }
188 return FailedPrecondition(
189 "Buffer %s already has the layout constraint %s, cannot add "
190 "incompatible constraint %s",
191 buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
192 LayoutUtil::HumanString(layout));
193 }
194 iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
195 } else {
196 TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
197 << buffer.ToString();
198 iter = buffer_constraints_
199 .insert(std::make_pair(
200 &buffer,
201 BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
202 .first;
203 }
204 added_constraints_.push_back(&iter->second);
205 return Status::OK();
206 }
207
// Adds (or replaces) a layout constraint on operand `operand_no` of
// `instruction`.
//
// Rules mirror SetBufferLayout: an equal existing constraint is a no-op; an
// existing mandatory constraint may not be contradicted. Additionally, the
// operand must not be forwarded to the instruction's output, since such a
// constraint would propagate beyond this immediate use.
Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
                                           const HloInstruction* instruction,
                                           int64 operand_no, bool mandatory,
                                           bool dfs) {
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  const OperandLayoutConstraint* curr_shape_layout =
      GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout, /*minor_to_major_only=*/true)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_shape_layout->mandatory()) {
      return FailedPrecondition(
          "Operand %d of instruction %s already has a layout constraint "
          "%s, cannot add incompatible constraint %s",
          operand_no, instruction->name(),
          curr_shape_layout->shape_layout().ToString(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout));
    }
  }

  // If any buffers in the operand occur in the output of the instruction, then
  // return an error. This case is not handled because such a constraint changes
  // layouts beyond this immediate use and is complicated to handle.
  if (OperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constraint layout of operand %d of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name());
  }

  // Insert a fresh constraint, or overwrite the existing (non-mandatory) one.
  auto key = std::make_pair(instruction, operand_no);
  auto iter = operand_constraints_.find(key);
  if (iter == operand_constraints_.end()) {
    auto pair = std::make_pair(
        key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
                                     instruction, operand_no, mandatory, dfs));
    iter = operand_constraints_.insert(pair).first;
  } else {
    iter->second =
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs);
  }
  added_constraints_.push_back(&iter->second);

  return Status::OK();
}
260
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)261 Status LayoutConstraints::SetArrayOperandLayout(
262 const Layout& layout, const HloInstruction* instruction, int64 operand_no,
263 bool mandatory, bool dfs) {
264 const HloInstruction* operand = instruction->operand(operand_no);
265 TF_RET_CHECK(operand->shape().IsArray());
266 Shape shape(operand->shape());
267 *shape.mutable_layout() = layout;
268 TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
269 return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
270 }
271
SetResultLayout(const Shape & shape_with_layout,bool dfs)272 Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
273 bool dfs) {
274 VLOG(3) << "SetResultLayout : "
275 << ShapeUtil::HumanStringWithLayout(shape_with_layout);
276
277 const ShapeLayout* curr_shape_layout = ResultLayout();
278 if (curr_shape_layout != nullptr) {
279 if (!curr_shape_layout->MatchesLayoutInShape(
280 shape_with_layout, /*minor_to_major_only=*/true)) {
281 return FailedPrecondition(
282 "Result of computation %s already has the layout constraint %s, "
283 "cannot add incompatible constraint %s",
284 computation_->name(), curr_shape_layout->ToString(),
285 ShapeUtil::HumanStringWithLayout(shape_with_layout));
286 }
287 // New constraint matches existing constraint. Nothing to do.
288 return Status::OK();
289 }
290 result_constraint_.reset(
291 new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
292 added_constraints_.push_back(result_constraint_.get());
293
294 return Status::OK();
295 }
296
// Constrains every array-shaped buffer in `instruction`'s output to the
// corresponding layout in `shape_with_layout`. Precondition (CHECKed below):
// the instruction defines all buffers in its output, so each shape index has
// exactly one logical buffer.
Status LayoutConstraints::SetInstructionLayout(
    const Shape& shape_with_layout, const HloInstruction* instruction,
    bool mandatory, bool dfs) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanStringWithLayout(shape_with_layout));
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory](const Shape& subshape,
                                     const ShapeIndex& index) -> Status {
        // The precondition for this method is that the instruction defines all
        // buffers in its output.
        auto buffers =
            points_to_analysis_.GetPointsToSet(instruction).element(index);
        CHECK_EQ(1, buffers.size());
        CHECK_EQ(buffers[0]->instruction(), instruction);

        // Tuple (non-array) subshapes carry no layout of their own; skip them.
        if (subshape.IsArray() && subshape.has_layout()) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
        } else {
          return Status::OK();
        }
      });
}
330
BufferLayout(const LogicalBuffer & buffer) const331 const Layout* LayoutConstraints::BufferLayout(
332 const LogicalBuffer& buffer) const {
333 if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
334 return &constraint->layout();
335 }
336 return nullptr;
337 }
338
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const339 const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
340 const LogicalBuffer& buffer) const {
341 auto it = buffer_constraints_.find(&buffer);
342 return it == buffer_constraints_.end() ? nullptr : &it->second;
343 }
344
OperandLayout(const HloInstruction * instruction,int64 operand_no) const345 const ShapeLayout* LayoutConstraints::OperandLayout(
346 const HloInstruction* instruction, int64 operand_no) const {
347 if (const auto* constraint =
348 GetOperandLayoutConstraint(instruction, operand_no)) {
349 return &constraint->shape_layout();
350 }
351 return nullptr;
352 }
353
GetOperandLayoutConstraint(const HloInstruction * instruction,int64 operand_no) const354 const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
355 const HloInstruction* instruction, int64 operand_no) const {
356 auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
357 return it == operand_constraints_.end() ? nullptr : &it->second;
358 }
359
ResultLayout() const360 const ShapeLayout* LayoutConstraints::ResultLayout() const {
361 return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
362 }
363
ToString() const364 string LayoutConstraints::ToString() const {
365 string output;
366 absl::StrAppend(&output, "LayoutConstraints for computation ",
367 computation_->name(), ":\n");
368 for (auto* instruction : computation_->MakeInstructionPostOrder()) {
369 absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
370 for (int64 i = 0; i < instruction->operand_count(); ++i) {
371 if (OperandLayout(instruction, i) != nullptr) {
372 absl::StrAppend(&output, " operand (", i,
373 "): ", OperandLayout(instruction, i)->ToString(), "\n");
374 }
375 }
376 for (const LogicalBuffer* buffer :
377 points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
378 if (BufferLayout(*buffer) != nullptr) {
379 absl::StrAppend(&output, " ", buffer->ToString(), " : ",
380 LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
381 }
382 }
383 }
384
385 if (ResultLayout() != nullptr) {
386 absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n");
387 }
388 return output;
389 }
390
391 namespace {
392
IsHostSendRecv(const HloInstruction * instruction)393 bool IsHostSendRecv(const HloInstruction* instruction) {
394 const HloSendRecvInstruction* send_recv_instr =
395 DynCast<HloSendRecvInstruction>(instruction);
396 return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
397 }
398
399 } // namespace
400
// Records the layout of every host-transfer Send/Recv channel in
// `computation` into host_channel_constraints_. Each host channel may be
// constrained at most once; a second, conflicting registration is an error.
Status LayoutAssignment::BuildHostChannelConstraints(
    HloComputation* computation) {
  for (auto* instruction : computation->instructions()) {
    const HloSendRecvInstruction* send_recv_instr =
        DynCast<HloSendRecvInstruction>(instruction);
    if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
      continue;
    }

    // For host transfers the Send and Recv instruction carry the layout.
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      // Tuple element 0 of a Send/Recv holds the transferred data; it must be
      // an array with a concrete layout.
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
      TF_RET_CHECK(data_shape.IsArray());
      TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
      // ConstrainChannel returns the previously registered layout (if any);
      // non-null here means the channel was already constrained differently.
      const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
          *send_recv_instr->channel_id(), data_shape.layout());
      TF_RET_CHECK(prev_layout == nullptr)
          << "Cannot constrain host transfer layout as it was set to "
          << LayoutUtil::HumanString(*prev_layout) << ": "
          << send_recv_instr->ToString();
    }
  }
  return Status::OK();
}
427
428 namespace {
429
IsLayoutConstrainedCustomCall(HloInstruction * instruction)430 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
431 const HloCustomCallInstruction* custom_call =
432 DynCast<HloCustomCallInstruction>(instruction);
433 return custom_call != nullptr && custom_call->layout_constrained();
434 }
435
IsLayoutConstrainedCollective(const HloInstruction * instruction)436 bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
437 const HloCollectiveInstruction* collective =
438 DynCast<HloCollectiveInstruction>(instruction);
439 return collective != nullptr && collective->constrain_layout();
440 }
441
442 } // namespace
443
// Seeds `constraints` with every layout that is non-negotiable for
// `computation`:
//   - instructions that define values with pre-existing layouts (infeed,
//     outfeed, parameters, layout-constrained custom calls/collectives, and
//     layout-constrained Send/Recv or cross-module all-reduce channels), and
//   - instructions that call computations whose ComputationLayout has
//     already been fixed (kCall, kWhile, kConditional).
// Finally the result layout is constrained to match the ComputationLayout
// (or the best-branch layout recorded in conditional_mismatch_), if any.
Status LayoutAssignment::AddMandatoryConstraints(
    const ComputationLayout* computation_layout,
    ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
    LayoutConstraints* constraints) {
  VLOG(3) << "Adding mandatory layout constraints to computation "
          << computation->name();

  // Host-transfer channels are tracked separately from device channels.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };

  // Constrain layouts of instructions which define values with pre-existing
  // layouts.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kInfeed) {
      // Infeed layouts must match the layout of the original inserted
      // instruction.
      // TODO(b/31425034): Change infeeds to be more like parameters, with
      // shapes in the ComputationLayout.
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kOutfeed) {
      // Constrain the input to the Outfeed instruction to be the expected
      // layout of the Outfeed.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          instruction->outfeed_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kParameter) {
      if (computation_layout != nullptr) {
        const ShapeLayout& parameter_layout =
            computation_layout->parameter_layout(
                instruction->parameter_number());
        // Parameter layouts must match the respective layout in
        // ComputationLayout, if there is one.
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            parameter_layout.shape(), instruction));
      }
    } else if (IsLayoutConstrainedCustomCall(instruction)) {
      // A layout-constrained custom call fixes both its result layout and the
      // layout of every operand.
      const HloCustomCallInstruction* custom_call =
          DynCast<HloCustomCallInstruction>(instruction);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(custom_call->shape(), custom_call));
      for (int64 i = 0; i < custom_call->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            custom_call->operand_shapes_with_layout()[i], custom_call, i));
      }
    } else if (instruction->opcode() == HloOpcode::kSend ||
               instruction->opcode() == HloOpcode::kRecv) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = *instruction->channel_id();
      // Channels that no one has constrained yet impose nothing here.
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      if (instruction->opcode() == HloOpcode::kSend) {
        // TODO(b/68493863): Change to use SetOperandLayout().
        const Shape send_buffer_shape = instruction->operand(0)->shape();
        TF_RET_CHECK(send_buffer_shape.IsArray());
        Shape new_buffer_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(send_buffer_shape,
                                        *instruction->channel_id());
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            new_buffer_shape, instruction->operand(0)));
      } else {
        // Recv: constrain the data buffer defined at tuple index {0} of the
        // Recv's output to the channel's layout.
        const Shape recv_buffer_shape =
            ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
        TF_RET_CHECK(recv_buffer_shape.IsArray());
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                 {0}));
        Shape new_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(recv_buffer_shape,
                                        *instruction->channel_id());
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    } else if (IsLayoutConstrainedCollective(instruction)) {
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->IsCrossModuleAllReduce()) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = instruction->channel_id().value();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      // TODO(b/68493863): Change to use SetOperandLayout().
      const Shape& buffer_shape = instruction->operand(0)->shape();
      TF_RET_CHECK(buffer_shape.IsArray());
      Shape new_buffer_shape =
          get_channel_constraints(instruction)
              ->LayoutShapeForChannel(buffer_shape, channel_id);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(new_buffer_shape, instruction));
    }
  }

  // Constrain layouts of instructions which call computations which have
  // already been assigned layouts. Instructions which call computations in a
  // parallel element-wise context (eg, map or reduce) do not need layout
  // constraints because they operate on scalars.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kCall) {
      // kCall instruction operands and output must match the ComputationLayout
      // of the called computation.
      const ComputationLayout& called_computation_layout =
          FindOrDie(computation_layouts_, instruction->to_apply());
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          called_computation_layout.result_layout().shape(), instruction));
      TF_RET_CHECK(instruction->operand_count() ==
                   called_computation_layout.parameter_count());
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            called_computation_layout.parameter_layout(i).shape(), instruction,
            i));
      }
    } else if (instruction->opcode() == HloOpcode::kWhile) {
      // Layout of input and output of kWhile instruction must be equal and must
      // match both input and output of body computation. Also, the input of
      // condition computation must match kWhile layout.
      HloComputation* body = instruction->while_body();
      HloComputation* condition = instruction->while_condition();
      const HloInstruction* init = instruction->operand(0);
      ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
      ComputationLayout& condition_layout =
          FindOrDie(computation_layouts_, condition);

      // Check a few invariants irrespective of layout.
      CHECK_EQ(1, instruction->operand_count());
      CHECK_EQ(1, body->num_parameters());
      CHECK_EQ(1, condition->num_parameters());
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   body_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   condition_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));

      // The body's result layout wins: rewrite the body and condition
      // parameter layouts to match it so the loop-carried value has one
      // consistent layout.
      if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
                << " while=" << instruction->name()
                << " shape=" << body_layout.result_layout().ToString();
        *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
      }
      if (condition_layout.parameter_layout(0) !=
          body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while condition parameter layout: cond="
                << condition->name() << " while=" << instruction->name()
                << " shape=" << body_layout.parameter_layout(0).ToString();
        *condition_layout.mutable_parameter_layout(0) =
            body_layout.parameter_layout(0);
      }

      // Constrain the output and the operand of the while instruction to match
      // the computations.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          body_layout.result_shape(), instruction, 0));
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          body_layout.result_shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kConditional) {
      // Find the conditional branch with the most instructions and force all
      // other computations to match that layout. A potentially better decision
      // could count the number FLOPs or how constrained the layouts are.
      int64 largest_branch = 0;
      int64 largest_instruction_count =
          instruction->branch_computation(0)->instruction_count();
      for (int j = 1; j < instruction->branch_count(); ++j) {
        const int64 instruction_count =
            instruction->branch_computation(j)->instruction_count();
        if (instruction_count > largest_instruction_count) {
          largest_branch = j;
          largest_instruction_count = instruction_count;
        }
      }
      ComputationLayout& best_branch_computation_layout =
          FindOrDie(computation_layouts_,
                    instruction->branch_computation(largest_branch));
      for (int k = 0; k < instruction->branch_count(); ++k) {
        // Visit the best branch first.
        // NOTE(review): `j` (rotated visit order) is used only for the
        // parameter-count check below, while every other statement in this
        // loop body indexes with `k` — confirm whether the body was meant to
        // use `j` consistently.
        int j = (k + largest_branch) % instruction->branch_count();
        TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
        ComputationLayout& branch_computation_layout =
            FindOrDie(computation_layouts_, instruction->branch_computation(k));
        if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
                best_branch_computation_layout.result_layout().shape(),
                /*minor_to_major_only=*/true)) {
          // The branch's result layout disagrees with the chosen branch:
          // discard its layout and remember the mismatch so the branch is
          // re-laid-out against best_branch_computation_layout later.
          computation_layouts_.erase(instruction->branch_computation(k));
          InsertOrDie(&conditional_mismatch_,
                      instruction->branch_computation(k),
                      best_branch_computation_layout);
        } else {
          TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
              branch_computation_layout.parameter_shape(0), instruction, k + 1,
              /*mandatory=*/true));
        }
      }
      // Operand (largest_branch + 1) feeds the chosen branch's parameter.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          best_branch_computation_layout.parameter_shape(0), instruction,
          largest_branch + 1,
          /*mandatory=*/true));
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          best_branch_computation_layout.result_shape(), instruction));
    }
  }
  // Finally set the result layout to match ComputationLayout, if there is one.
  if (conditional_mismatch_.count(computation) > 0) {
    TF_RETURN_IF_ERROR(constraints->SetResultLayout(
        FindOrDie(conditional_mismatch_, computation).result_layout().shape()));
  } else if (computation_layout != nullptr) {
    const ShapeLayout& result_layout = computation_layout->result_layout();
    if (result_layout.LayoutIsSet()) {
      TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
    }
  }
  return Status::OK();
}
664
665 namespace {
666
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)667 bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
668 return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
669 }
670
671 // The operands of a call must match the layouts of parameters in the
672 // ComputationLayout, and the call instruction itself must match the result
673 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)674 Status CheckCallLayout(HloInstruction* call,
675 const ComputationLayout& computation_layout) {
676 HloComputation* computation = call->to_apply();
677 TF_RET_CHECK(computation->num_parameters() == call->operand_count());
678 for (int64 i = 0; i < computation->num_parameters(); ++i) {
679 TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
680 call->operand(i)->shape(), /*minor_to_major_only=*/true));
681 }
682 TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
683 call->shape(), /*minor_to_major_only=*/true));
684 return Status::OK();
685 }
686
687 // Operands of layout-constrained custom calls must match the expected
688 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)689 Status CheckCustomCallLayout(HloInstruction* instruction) {
690 if (IsLayoutConstrainedCustomCall(instruction)) {
691 const HloCustomCallInstruction* custom_call =
692 DynCast<HloCustomCallInstruction>(instruction);
693 for (int64 i = 0; i < custom_call->operand_count(); ++i) {
694 TF_RET_CHECK(
695 LayoutsInShapesEqual(custom_call->operand(i)->shape(),
696 custom_call->operand_shapes_with_layout()[i]));
697 }
698 }
699 return Status::OK();
700 }
701
702 // For a while instruction, all the following layouts must be the same:
703 // (1) init operand
704 // (2) condition computation parameter
705 // (3) body computation parameter
706 // (4) body computation result
707 // (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)708 Status CheckWhileLayout(HloInstruction* while_inst,
709 const ComputationLayout& condition_computation_layout,
710 const ComputationLayout& body_computation_layout) {
711 auto init_shape = while_inst->operand(0)->shape();
712 TF_RET_CHECK(
713 condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
714 init_shape, /*minor_to_major_only=*/true));
715 TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
716 init_shape, /*minor_to_major_only=*/true));
717 TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
718 init_shape, /*minor_to_major_only=*/true));
719 TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
720 return Status::OK();
721 }
722
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)723 Status CheckConditionalLayout(
724 HloInstruction* instruction,
725 absl::Span<const ComputationLayout> branch_computation_layouts) {
726 for (int j = 0; j < instruction->branch_count(); ++j) {
727 const HloInstruction* branch_operand = instruction->operand(j + 1);
728 TF_RET_CHECK(
729 branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
730 branch_computation_layouts[j].result_layout().shape(),
731 /*minor_to_major_only=*/true));
732 TF_RET_CHECK(
733 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
734 instruction->shape(), /*minor_to_major_only=*/true));
735 TF_RET_CHECK(
736 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
737 instruction->branch_computation(j)->root_instruction()->shape(),
738 /*minor_to_major_only=*/true));
739 TF_RET_CHECK(
740 branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
741 branch_operand->shape(), /*minor_to_major_only=*/true));
742 }
743 return Status::OK();
744 }
745
746 // Fusion parameters must match the layout of the fusion instructions operands,
747 // and the root of the fusion expression must match the layout of the fusion
748 // instruction.
CheckFusionLayout(HloInstruction * fusion)749 Status CheckFusionLayout(HloInstruction* fusion) {
750 TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
751
752 TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
753 fusion->fused_expression_root()->shape()));
754 for (int64 i = 0; i < fusion->operand_count(); ++i) {
755 TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
756 fusion->operand(i)->shape()));
757 }
758 return Status::OK();
759 }
760
761 // The layout of a parameter must match the respective layout in the
762 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)763 Status CheckParameterLayout(HloInstruction* parameter,
764 const ComputationLayout& computation_layout) {
765 const ShapeLayout& parameter_layout =
766 computation_layout.parameter_layout(parameter->parameter_number());
767 return ShapeUtil::ForEachSubshapeWithStatus(
768 parameter_layout.shape(),
769 [&](const Shape& subshape, const ShapeIndex& shape_index) {
770 if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
771 !subshape.has_layout()) {
772 return Status::OK();
773 }
774 if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
775 subshape,
776 ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
777 return InternalError(
778 "parameter instruction %s does not match layout of computation "
779 "shape: %s",
780 parameter->ToString(), parameter_layout.ToString());
781 }
782 return Status::OK();
783 });
784 }
785
786 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)787 Status CheckConstantLayout(HloInstruction* constant) {
788 if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
789 return InternalError(
790 "constant instruction %s does not match the layout of its literal %s",
791 constant->ToString(),
792 ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
793 }
794 return Status::OK();
795 }
796
797 } // namespace
798
// Creates new instruction(s) producing the same value as `instruction` but
// laid out as `shape_with_layout`. Arrays get a single kCopy; tuples are
// decomposed via get-tuple-element, each differing element is copied
// recursively, and the results are reassembled with a fresh kTuple.
// `shape_with_layout` must be compatible with the instruction's shape and
// fully layouted.
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      // Extract the element; it may be reused directly if its layout already
      // matches the target.
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
                                                    instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // Stamp the requested layout onto the freshly created tuple shape.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    // Record the added copy so the pass can track/undo copies it introduced.
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
855
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64 operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
                                                operand->shape())) {
    VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }
  VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // If the operand is only used by a conditional, do the copy inside the branch
  // to avoid overhead for other branches.
  if (instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
      instruction->operand(operand_no)->user_count() == 1) {
    // Operand `operand_no` feeds branch computation `operand_no - 1`
    // (operand 0 of a conditional is the branch selector).
    auto branch_comp = instruction->branch_computation(operand_no - 1);
    auto param = branch_comp->parameter_instruction(0);
    // Give the branch parameter the operand's current layout, then insert the
    // layout-converting copy inside the branch body.
    *param->mutable_shape() = operand->shape();
    auto param_users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
                        CreateCopyWithNewLayout(operand_layout.shape(), param));
    for (auto user : param_users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
    }
    VLOG(4) << "New copy of " << operand->ToString() << " is "
            << param_copy->ToString();
    if (param == branch_comp->root_instruction()) {
      // The parameter itself was the root; the copy becomes the new root.
      branch_comp->set_root_instruction(param_copy,
                                        /*accept_different_shape=*/true);
    }
    // Keep the recorded computation layout consistent with the parameter's
    // new shape.
    *FindOrDie(computation_layouts_, branch_comp).mutable_parameter_layout(0) =
        ShapeLayout(operand->shape());
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
907
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)908 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
909 HloInstruction* copy,
910 const ShapeIndex& index) {
911 if (instruction.has_sharding()) {
912 // If the index is empty, we want to copy the whole sharding, in case the
913 // sharding is a tuple sharding.
914 HloSharding sharding =
915 !index.empty() && instruction.sharding().IsTuple()
916 ? instruction.sharding().GetSubSharding(instruction.shape(), index)
917 : instruction.sharding();
918 // We propagate the sharding to the copied instruction only if it is a
919 // special sharding, like tiled ones.
920 // Otherwise it is preferable to leave the new instruction without device,
921 // and let the automatic device placer to choose the best location.
922 auto device = sharding.UniqueDevice();
923 if (!device || HloSharding::IsReservedDevice(*device)) {
924 copy->set_sharding(sharding);
925 }
926 }
927 copy->set_metadata(instruction.metadata());
928 }
929
// Verifies that layout assignment left the module in a consistent state:
// every instruction has a valid layout, every leaf output agrees with the
// logical buffers that may produce it, opcode-specific layout contracts
// (call/custom-call/fusion/parameter/constant/while/conditional) hold, and
// the entry result layout (if set) matches the root.
Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              for (const LogicalBuffer* buffer : buffers) {
                if (!Shape::Equal()
                         .IgnoreDynamicDimension()
                         .MinorToMajorOnlyInLayout()(instruction_subshape,
                                                     buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional: {
          // Gather all branch layouts into one span for the check.
          std::vector<ComputationLayout> branch_computation_layouts;
          for (auto branch_computation : instruction->branch_computations()) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation));
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            module->result_shape(), result_layout.shape()));
  }
  return Status::OK();
}
1023
// Constructs the pass. `entry_computation_layout` constrains the entry
// computation's parameter/result layouts (a pristine copy is saved so it can
// be restored later); `instruction_can_change_layout_func` tells the pass
// which instructions may have different operand and result layouts;
// `channel_constraints` is optional and, when given, is snapshotted so it can
// be reset if previous side effects must be undone.
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    std::function<bool(const HloInstruction*)>
        instruction_can_change_layout_func,
    ChannelLayoutConstraints* channel_constraints)
    : entry_computation_layout_(entry_computation_layout),

      saved_entry_computation_layout_(*entry_computation_layout),
      channel_layout_constraints_(channel_constraints),
      instruction_can_change_layout_func_(
          std::move(instruction_can_change_layout_func)) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
1043
// Given the layout chosen for `instruction`'s (array-shaped) result, proposes
// a layout for its array-shaped operand `operand_no`, or returns nullptr when
// there is no preference. Covers three cases: non-layout-changing ops (mirror
// the output layout), reshapes (pick a layout that makes the reshape a
// bitcast), and transposes (pick the layout that makes the transpose a
// bitcast).
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
    const Layout& output_layout, const HloInstruction* instruction,
    int64 operand_no) {
  const HloInstruction* operand = instruction->operand(operand_no);
  CHECK(instruction->shape().IsArray());
  CHECK(operand->shape().IsArray());
  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == instruction->shape().rank() &&
      !instruction_can_change_layout_func_(instruction)) {
    // Propagate the result layout to the operand layout if the instruction
    // requires the same layout for the result and the operand.
    //
    // For elementwise operations, using the same layout for the operands and
    // the result also has the following benefits:
    // 1) the elementwise operation can reuse its operand's buffer, and
    // 2) the input and output elements can reuse the same linear index.
    return absl::make_unique<Layout>(output_layout);
  }

  if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the operand's layout to the output.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(instruction->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }

    const Shape& output_shape = instruction->shape();
    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
        LayoutUtil::MinorToMajor(output_layout));
    Shape operand_shape = operand->shape();
    *operand_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(operand_shape);
    // AlignLayouts finds an operand layout under which the reshape is a
    // bitcast, if one exists.
    auto aligned_operand_shape =
        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
    if (aligned_operand_shape) {
      auto operand_layout = aligned_operand_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
      return absl::make_unique<Layout>(operand_layout);
    }
  }

  if (instruction->opcode() == HloOpcode::kTranspose) {
    // Pick the operand layout that makes the transpose a bitcast.
    int64 rank = instruction->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    for (int64 i = 0; i < rank; ++i) {
      // The i-th minor output dimension maps back to this operand dimension.
      int64 output_dim = LayoutUtil::Minor(output_layout, i);
      int64 operand_dim = instruction->dimensions(output_dim);
      new_minor_to_major[i] = operand_dim;
    }
    Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(
        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
    return absl::make_unique<Layout>(operand_layout);
  }

  return nullptr;
}
1110
// Mirror image of ChooseOperandLayoutFromOutputLayout: given the layout
// chosen for `user`'s operand `operand_no`, proposes a layout for `user`'s
// own (array-shaped) result, or returns nullptr when there is no preference.
std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
    const Layout& operand_layout, const HloInstruction* user,
    int64 operand_no) {
  const HloInstruction* operand = user->operand(operand_no);

  CHECK(user->shape().IsArray() && operand->shape().IsArray());

  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == user->shape().rank() &&
      !instruction_can_change_layout_func_(user)) {
    // Assign users the same layout as the operand.
    return absl::make_unique<Layout>(operand_layout);
  }

  if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the output's layout to the operand.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(user->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        operand->shape().element_type(),
        AsInt64Slice(operand->shape().dimensions()),
        LayoutUtil::MinorToMajor(operand_layout));
    Shape output_shape = user->shape();
    *output_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(output_shape);
    // AlignLayouts finds an output layout under which the reshape is a
    // bitcast, if one exists.
    auto aligned_user_shape =
        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
    if (aligned_user_shape) {
      auto user_layout = aligned_user_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
      return absl::make_unique<Layout>(user_layout);
    }
  }

  if (user->opcode() == HloOpcode::kTranspose) {
    // Pick the user layout that makes the transpose a bitcast.
    int64 rank = user->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    // Invert the transpose permutation to map operand dims to output dims.
    auto inverse_dimensions = InversePermutation(user->dimensions());
    for (int64 i = 0; i < rank; ++i) {
      int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
      int64 user_dim = inverse_dimensions[operand_dim];
      new_minor_to_major[i] = user_dim;
    }
    Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
    return absl::make_unique<Layout>(user_layout);
  }

  return nullptr;
}
1171
PropagateConstraints(LayoutConstraints * constraints)1172 Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
1173 // Gathers all initial constraints in a worklist and propagates them in
1174 // depth-first order. DFS order seems to be better than BFS because a
1175 // constraint is propagated as far as possible before propagating unrelated
1176 // constraints which makes it less likely that conflicting constraints will be
1177 // propagated to instructions. However, we should experiment with other orders
1178 // too.
1179 std::deque<const LayoutConstraint*> worklist;
1180
1181 // Lambda for moving newly added constraints to the worklist.
1182 auto add_new_constraints_to_worklist = [constraints, &worklist]() {
1183 // Add constraints to the front of the deque for DFS ordering.
1184 for (auto* constraint : constraints->ConsumeAddedConstraints()) {
1185 if (constraint->dfs()) {
1186 worklist.push_front(constraint);
1187 } else {
1188 worklist.push_back(constraint);
1189 }
1190 }
1191 };
1192 add_new_constraints_to_worklist();
1193
1194 while (!worklist.empty()) {
1195 const LayoutConstraint* layout_constraint = worklist.front();
1196 worklist.pop_front();
1197 VLOG(2) << "Propagating " << layout_constraint->ToString()
1198 << " to its neighbors.";
1199 if (auto* buffer_constraint =
1200 dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
1201 TF_RETURN_IF_ERROR(
1202 PropagateBufferConstraint(*buffer_constraint, constraints));
1203 } else if (auto* operand_constraint =
1204 dynamic_cast<const OperandLayoutConstraint*>(
1205 layout_constraint)) {
1206 TF_RETURN_IF_ERROR(
1207 PropagateOperandConstraint(*operand_constraint, constraints));
1208 } else if (auto* result_constraint =
1209 dynamic_cast<const ResultLayoutConstraint*>(
1210 layout_constraint)) {
1211 TF_RETURN_IF_ERROR(
1212 PropagateResultConstraint(*result_constraint, constraints));
1213 } else {
1214 LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
1215 }
1216
1217 add_new_constraints_to_worklist();
1218 }
1219 return Status::OK();
1220 }
1221
1222 namespace {
1223
1224 // Returns a vector containing all array-shaped uses (instruction and operand
1225 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const LogicalBuffer & buffer,const TuplePointsToAnalysis & points_to_analysis)1226 std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
1227 const LogicalBuffer& buffer,
1228 const TuplePointsToAnalysis& points_to_analysis) {
1229 CHECK(buffer.IsArray());
1230 std::vector<std::pair<const HloInstruction*, int64>> uses;
1231 for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
1232 if (!buffer_alias.instruction()->shape().IsArray()) {
1233 continue;
1234 }
1235 // This alias must be the top-level (index == {}) of the instruction's
1236 // result because the instruction produces an array.
1237 CHECK(buffer_alias.index().empty());
1238
1239 // Add all uses of the instruction's output.
1240 for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1241 for (int64 operand_no :
1242 user->OperandIndices(buffer_alias.instruction())) {
1243 uses.emplace_back(user, operand_no);
1244 }
1245 }
1246 }
1247 return uses;
1248 }
1249
1250 } // namespace
1251
PropagateUseConstraintToDefs(const ShapeLayout & shape_layout,const HloInstruction * instruction,LayoutConstraints * constraints)1252 Status LayoutAssignment::PropagateUseConstraintToDefs(
1253 const ShapeLayout& shape_layout, const HloInstruction* instruction,
1254 LayoutConstraints* constraints) {
1255 // Try to set all logical buffers which may be sources of the given operand to
1256 // match the given layout.
1257 const PointsToSet& points_to_set =
1258 constraints->points_to_analysis().GetPointsToSet(instruction);
1259 return points_to_set.ForEachElementWithStatus(
1260 [&shape_layout, constraints](
1261 const ShapeIndex& index,
1262 const PointsToSet::BufferList& buffers) -> Status {
1263 if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
1264 for (const LogicalBuffer* buffer : buffers) {
1265 if (constraints->BufferLayout(*buffer) == nullptr &&
1266 buffer->shape().IsArray()) {
1267 TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
1268 ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
1269 *buffer, /*mandatory=*/true));
1270 }
1271 }
1272 }
1273 return Status::OK();
1274 });
1275 }
1276
1277 namespace {
1278 // A transpose or a reshape that only changes trivial dimensions have meaningful
1279 // layouts that are valuable to propagate in a depthfirst manner to avoid
1280 // unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo,bool forward_propagation=true)1281 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
1282 bool forward_propagation = true) {
1283 switch (hlo.opcode()) {
1284 case HloOpcode::kFusion:
1285 return hlo.IsCustomFusion();
1286 case HloOpcode::kGather:
1287 return true;
1288 case HloOpcode::kReshape:
1289 return hlo.operand(0)->shape().rank() == 1 ||
1290 (forward_propagation &&
1291 std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
1292 case HloOpcode::kScatter:
1293 case HloOpcode::kTranspose:
1294 return true;
1295 default:
1296 return false;
1297 }
1298 }
1299
1300 } // namespace
1301
// Propagates a layout constraint placed on one operand of an instruction:
// backwards to the buffers that may define that operand, sideways to sibling
// operands of non-layout-changing users, and forwards to the buffers the user
// itself defines.
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    return Status::OK();
  }

  if (user->opcode() == HloOpcode::kAllReduce) {
    // Pick the output buffer matching this operand: the top-level buffer for
    // a single-operand all-reduce, tuple element {operand_no} otherwise.
    const auto shape_index =
        user->operand_count() == 1
            ? ShapeIndex()
            : ShapeIndex({operand_constraint.operand_no()});
    TF_ASSIGN_OR_RETURN(const LogicalBuffer* buffer,
                        constraints->points_to_analysis().GetBufferDefinedAt(
                            user, shape_index));
    const BufferLayoutConstraint* constraint =
        constraints->GetBufferLayoutConstraint(*buffer);
    if (constraint == nullptr) {
      // No constraint yet on the all-reduce output: suggest (non-mandatory)
      // the operand's layout.
      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          operand_constraint.shape_layout().layout(), *buffer,
          /*mandatory=*/false));
    }
  }
  if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
    return Status::OK();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->OperandBufferForwarded(user,
                                          operand_constraint.operand_no())) {
    return Status::OK();
  }

  int64 operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    return Status::OK();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!instruction_can_change_layout_func_(user)) {
    // Only propagate the layout of the largest concatenate operand.
    if (user->opcode() == HloOpcode::kConcatenate) {
      for (int64 operand_no = 0; operand_no < user->operand_count();
           ++operand_no) {
        const HloInstruction* sibling = user->operand(operand_no);
        if (sibling == operand) {
          continue;
        }
        if (sibling->shape().dimensions(user->concatenate_dimension()) >
            operand->shape().dimensions(user->concatenate_dimension())) {
          return Status::OK();
        }
      }
    }
    // Make sure all siblings have the same layout as the operand.
    for (int64 operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      const int64 sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        //   operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        //   we continue to propagate our layout to the result, and then
        //   backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false));
    }
    // Also suggest the operand's layout for the same-rank array subshapes of
    // the user's own output.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return Status::OK();
          }
          if (subshape.rank() <= 1) {
            return Status::OK();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return Status::OK();
          }
          if (!constraints->points_to_analysis()
                   .InstructionDefinesBufferAtIndex(user, shape_index)) {
            return Status::OK();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(
                  user, shape_index));
          // Make sure the output has the same layout as the operand.
          const BufferLayoutConstraint* constraint =
              constraints->GetBufferLayoutConstraint(*buffer);
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          if (constraint == nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                operand_constraint.shape_layout().layout(), *buffer,
                /*mandatory=*/false));
          }
          return Status::OK();
        }));
    return Status::OK();
  }
  // The user may change layout: ask the opcode-aware chooser for a preferred
  // output layout derived from this operand's layout.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return Status::OK();
        }
        if (subshape.rank() <= 1) {
          return Status::OK();
        }
        if (!constraints->points_to_analysis().InstructionDefinesBufferAtIndex(
                user, shape_index)) {
          return Status::OK();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(user,
                                                                 shape_index));
        if (constraints->BufferLayout(*buffer) == nullptr ||
            !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}
1482
// Propagates 'buffer_constraint' backwards from the buffer's defining
// instruction into that instruction's operands. For layout-preserving
// instructions the output layout is copied to the operands as a mandatory
// constraint; otherwise a backend-chosen operand layout is only suggested
// (non-mandatory) and may be overridden later.
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  // Shapes of rank <= 1 have a unique layout; nothing to propagate.
  if (IsAtMostRank1(instruction->shape())) {
    return Status::OK();
  }

  if (instruction->opcode() == HloOpcode::kAllReduce) {
    // For a multi-operand (tuple-shaped) all-reduce, the leading buffer index
    // selects which operand this buffer corresponds to; a single-operand
    // all-reduce always maps to operand 0.
    TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
        buffer_constraint.layout(), instruction,
        instruction->operand_count() == 1 ? 0 : buffer.index()[0],
        /*mandatory=*/true));
    return Status::OK();
  }
  for (int64 operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!instruction_can_change_layout_func_(instruction)) {
      // Copy the layout to the operand.
      // Only valid when ranks match; e.g. a layout-preserving reshape-like op
      // with a different operand rank cannot take the layout verbatim.
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true));
      }
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no constraint already.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        // nullptr means the backend has no preference for this operand.
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/
              InstructionShouldPropagateDepthFirst(
                  *instruction, /*forward_propagation=*/false)));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return Status::OK();
}
1546
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1547 Status LayoutAssignment::PropagateBufferConstraint(
1548 const BufferLayoutConstraint& buffer_constraint,
1549 LayoutConstraints* constraints) {
1550 // Only propagate array layouts.
1551 const LogicalBuffer& buffer = buffer_constraint.buffer();
1552 if (!buffer.IsArray()) {
1553 return Status::OK();
1554 }
1555 TF_RETURN_IF_ERROR(
1556 PropagateBufferConstraintToUses(buffer_constraint, constraints));
1557 return PropagateBufferConstraintToOperands(buffer_constraint, constraints);
1558 }
1559
// Propagates a buffer layout constraint forward to (a) every array-typed use
// of the buffer and (b) across the backedge of an enclosing kWhile loop, if
// the buffer's computation is a while body whose root feeds back into the
// body parameter.
Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
    const HloInstruction* user = user_operand_no.first;
    int64 operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !constraints->OperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
    }
  }

  // Propagate to backedges of kWhile. Only applies when this computation is
  // called from exactly one site, and that caller is a kWhile.
  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
  if (node.caller_callsites().size() != 1) {
    return Status::OK();
  }
  const HloInstruction* parent = node.caller_callsites()[0].instruction();
  if (parent->opcode() != HloOpcode::kWhile) {
    return Status::OK();
  }

  for (HloInstruction* user : buffer.instruction()->users()) {
    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
      continue;
    }
    if (user->parent()->root_instruction() == user) {
      VLOG(3) << "Propagating layout through backedge"
              << buffer_constraint.layout().ToString();
      int64 index = user->operand_index(buffer.instruction());
      // The while-body root tuple feeds back into parameter 0, so constrain
      // the corresponding element of the parameter's buffer to this layout.
      TF_ASSIGN_OR_RETURN(
          auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
                           user->parent()->parameter_instruction(0), {index}));

      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          buffer_constraint.layout(), *buffer, /*mandatory=*/false));
    }
  }

  return Status::OK();
}
1610
PropagateResultConstraint(const ResultLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1611 Status LayoutAssignment::PropagateResultConstraint(
1612 const ResultLayoutConstraint& layout_constraint,
1613 LayoutConstraints* constraints) {
1614 // Propagate the use constraint of the root instruction up to the logical
1615 // buffers which make up the result.
1616 return PropagateUseConstraintToDefs(
1617 layout_constraint.shape_layout(),
1618 constraints->computation()->root_instruction(), constraints);
1619 }
1620
1621 namespace {
1622
1623 // Infers the layout of the array at the given index in the given instruction's
1624 // output using points-to analysis. Precondition: The given instruction must
1625 // not produce this array value (that is, the array is forwarded from the
1626 // instruction's operands).
InferArrayLayout(const TuplePointsToAnalysis & points_to_analysis,HloInstruction * instruction,const ShapeIndex & index)1627 StatusOr<Layout> InferArrayLayout(
1628 const TuplePointsToAnalysis& points_to_analysis,
1629 HloInstruction* instruction, const ShapeIndex& index) {
1630 // This function should only be called for array shapes which don't yet have
1631 // layouts.
1632 const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
1633 TF_RET_CHECK(subshape.IsArray());
1634 TF_RET_CHECK(!subshape.has_layout());
1635
1636 // The instruction should not define the buffer at this index.
1637 TF_RET_CHECK(
1638 !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
1639 << instruction->ToString();
1640
1641 const auto& source_buffers =
1642 points_to_analysis.GetPointsToSet(instruction).element(index);
1643 TF_RET_CHECK(!source_buffers.empty());
1644
1645 // Verify the layout is the same for every LogicalBuffer which this location
1646 // ('instruction' and 'index') points to.
1647 const Layout* first_buffer_layout = nullptr;
1648 for (const LogicalBuffer* source_buffer : source_buffers) {
1649 if (!source_buffer->shape().has_layout()) {
1650 // This should not happen because we've assigned layouts to all
1651 // instructions preceding this one.
1652 return InternalError("LogicalBuffer %s does not have a layout",
1653 source_buffer->ToString());
1654 }
1655
1656 if (first_buffer_layout == nullptr) {
1657 first_buffer_layout = &source_buffer->shape().layout();
1658 } else if (!Layout::Equal().MinorToMajorOnly()(
1659 source_buffer->shape().layout(), *first_buffer_layout)) {
1660 // The points-to set is ambiguous for this index and the different source
1661 // buffers have different layouts. This case is possible in valid XLA
1662 // computations because we do not propagate BufferLayoutConstraints to all
1663 // LogicalBuffers which may alias the constrained LogicalBuffer at some
1664 // point in the computation.
1665 return FailedPrecondition(
1666 "Array at index {%s} in instruction %s aliases buffers %s "
1667 "and %s which have different layouts",
1668 absl::StrJoin(index, ","), instruction->name(),
1669 source_buffers[0]->ToString(), source_buffer->ToString());
1670 }
1671 }
1672
1673 return *first_buffer_layout;
1674 }
1675
// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
// instruction itself. All other fused instructions have their layouts
// cleared (except in custom fusions and a few special cases below).
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Fused parameters adopt the layout of the matching fusion operand.
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes.
      // But do not clear layouts for other instructions in custom fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return Status::OK();
}
1721
1722 } // namespace
1723
// Applies the resolved layout constraints to every instruction of
// 'computation': clears stale layouts, stamps constrained buffer layouts,
// infers layouts of forwarded (aliased) subshapes via points-to analysis,
// inserts copies where operand constraints disagree with producer layouts,
// and finalizes fusion-internal layouts. At exit every instruction shape has
// a complete, validated layout.
Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
                                       HloComputation* computation) {
  VLOG(2) << "Assigning layouts to computation: " << computation->name();
  XLA_VLOG_LINES(2, computation->ToString());
  XLA_VLOG_LINES(2, constraints.ToString());

  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Set the layouts of the array shapes this instruction defines as indicated
    // by the respective BufferLayoutConstraints. Any array shapes in the output
    // of the instruction which are not defined by the instruction (eg, array
    // elements in a Tuple instruction) will be assigned below via inference.
    for (const LogicalBuffer* buffer :
         constraints.points_to_analysis().GetBuffersDefinedByInstruction(
             instruction)) {
      if (!buffer->shape().IsArray()) {
        continue;
      }

      TF_RET_CHECK(buffer->instruction() == instruction);
      const Layout* buffer_layout = constraints.BufferLayout(*buffer);
      TF_RET_CHECK(buffer_layout != nullptr);

      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(*buffer_layout, buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = *buffer_layout;
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !subshape->IsArray()) {
            return Status::OK();
          }
          // Set Layout of subshape to match layout of LogicalBuffer which
          // produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(constraints.points_to_analysis(),
                                               instruction, index));
          return Status::OK();
        }));

    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      }
    }

    // Fusion instructions require some layouts to be set on fused instructions
    // inside the fusion instruction.
    if (instruction->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    // Execute extra verification step once the layout has been finalized.
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Shape must be valid.
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }
  return Status::OK();
}
1805
CalculateComputationLayout(HloComputation * computation)1806 Status LayoutAssignment::CalculateComputationLayout(
1807 HloComputation* computation) {
1808 ComputationLayout computation_layout(computation->ComputeProgramShape(),
1809 /*ignore_layouts=*/false);
1810 InsertOrDie(&computation_layouts_, computation, computation_layout);
1811 VLOG(2) << " Calculated ComputationLayout = "
1812 << computation_layout.ToString();
1813 return Status::OK();
1814 }
1815
ClearComputationLayouts(HloComputation * computation)1816 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
1817 // Clear existing layouts of the instructions. All layouts must be assigned
1818 // by the LayoutAssignment pass, except for those on parameters, the
1819 // computation result, and a couple special cases. The former two are
1820 // specified in computation_layout. Clearing the layouts here avoids hiding
1821 // potential bugs in the layout assignment pass that may accidentally use the
1822 // existing layout.
1823 for (HloInstruction* instruction : computation->instructions()) {
1824 if (instruction->opcode() == HloOpcode::kBitcast) {
1825 // bitcasts are inherently layout sensitive and so a bitcast instruction
1826 // present in the IR before layout assignment is a bug.
1827 return InternalError(
1828 "Unexpected bitcast operation seen during layout assignment: %s.",
1829 instruction->ToString());
1830 }
1831 // Some instructions carry mandatory layouts in their shape.
1832 if (instruction->opcode() != HloOpcode::kInfeed &&
1833 !IsLayoutConstrainedCustomCall(instruction) &&
1834 !IsLayoutConstrainedCollective(instruction)) {
1835 LayoutUtil::ClearLayout(instruction->mutable_shape());
1836 }
1837 }
1838 return Status::OK();
1839 }
1840
// Runs layout assignment on a single (non-fusion) computation:
//   1. registers host-channel and computation-layout bookkeeping,
//   2. collects mandatory and backend-specific constraints,
//   3. propagates constraints to a fixed point, assigning default layouts to
//      any buffers left unconstrained,
//   4. applies the final layouts, records channel constraints, and inserts a
//      copy of the root if it does not satisfy the result constraint.
// 'computation_layout' may be null on the first pass; in that case the
// resulting layout is computed from the assignment afterwards.
Status LayoutAssignment::RunOnComputation(
    ComputationLayout* computation_layout, HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
          << ")";

  // Must be run before clearing layouts.
  TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));

  TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
  if (computation_layout != nullptr) {
    auto it = computation_layouts_.find(computation);
    if (it == computation_layouts_.end()) {
      VLOG(2) << "  New ComputationLayout = " << computation_layout->ToString();
      computation_layouts_.emplace(computation, *computation_layout);
    } else {
      TF_RET_CHECK(computation_layout == &it->second ||
                   computation_layout == entry_computation_layout_);
      VLOG(2) << "  Existing ComputationLayout = "
              << computation_layout->ToString();
    }
  } else {
    VLOG(2) << "  No ComputationLayout specified (will be calculated)";
  }

  // Construct LayoutConstraints with all layout constraints of the computation.
  LayoutConstraints constraints(*points_to_analysis_, computation);

  // Add constraints required for correctness on all backends (eg, entry
  // parameter layout constraints).
  TF_RETURN_IF_ERROR(AddMandatoryConstraints(
      computation_layout, channel_constraints, computation, &constraints));

  // Add any backend-specific constraints.
  TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));

  // Propagates layouts from mandatory and backend constraints.
  TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

  // Prior to applying default layouts, we take note of all HLO instructions
  // which lack a layout constraint.
  for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
    unconstrained_layout_instructions_.insert(
        points_to_analysis_->GetBuffer(buffer_id).instruction());
  }

  // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
  // layout and propagate the change.
  while (!constraints.unconstrained_buffer_ids().empty()) {
    int unconstrained_count = constraints.unconstrained_buffer_ids().size();

    // Arbitrarily pick the first unconstrained buffer and give it the default
    // layout (or the literal layout, in case of constants). By construction
    // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
    const LogicalBuffer& buffer = points_to_analysis_->GetBuffer(
        *constraints.unconstrained_buffer_ids().begin());
    const HloInstruction* instruction = buffer.instruction();
    Layout new_layout =
        instruction->opcode() == HloOpcode::kConstant
            ? ShapeUtil::GetSubshape(instruction->literal().shape(),
                                     buffer.index())
                  .layout()
            : GetUnconstrainedLayout(buffer);
    TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
                                                   /*mandatory=*/false));

    TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

    // To verify progress has been made, check that the number of unconstrained
    // buffers has been reduced.
    CHECK_LT(constraints.unconstrained_buffer_ids().size(),
             unconstrained_count);
  }
  // All logical buffers should have constraints at this point. All that
  // remains is assign the constraints to the buffers and infer layouts for
  // aliased buffers.
  TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));

  // If the computation layout wasn't specified, now it is the time to compute
  // it according to the parameters and root instruction layouts.
  // This allows the first pass through this API to record the best flowing
  // layout to parameters and root instruction.
  if (computation_layout == nullptr) {
    TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
  }

  // Record the layouts assigned for any communication ops in
  // channel_constraints so that they are constrained for future modules.
  if (channel_constraints != nullptr) {
    TF_RETURN_IF_ERROR(
        ConstrainChannelLayouts(computation, channel_constraints));
  }

  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr) {
    // Layout assignment at this point only does minor-to-major assignment so
    // tiling info should be ignored here for comparison.
    if (!constraints.ResultLayout()->MatchesLayoutInShape(
            computation->root_instruction()->shape(),
            /*minor_to_major_only=*/true)) {
      // A recorded conditional mismatch overrides the result layout so that
      // all branches of the conditional agree.
      if (conditional_mismatch_.count(computation) > 0) {
        *FindOrDie(computation_layouts_, computation).mutable_result_layout() =
            FindOrDie(conditional_mismatch_, computation).result_layout();
      }
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_root,
          CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                  computation->root_instruction()));
      computation->set_root_instruction(new_root);
    } else {
      // Copy the specified tiling info.
      auto assign_tiling = [&constraints](xla::Shape* subshape,
                                          const xla::ShapeIndex& index) {
        if (subshape->IsArray()) {
          const Shape& result_shape = ShapeUtil::GetSubshape(
              constraints.ResultLayout()->shape(), index);
          subshape->mutable_layout()->mutable_tiles()->assign(
              result_shape.layout().tiles().begin(),
              result_shape.layout().tiles().end());
        }
      };
      xla::ShapeUtil::ForEachMutableSubshape(
          computation->root_instruction()->mutable_shape(), assign_tiling);
    }
  }
  return Status::OK();
}
1969
// Records the layouts assigned to channel instructions (send/recv and
// cross-module all-reduce) in the appropriate channel-constraint table so
// that future users of the same channel are forced onto the same layout.
// Host channels are tracked separately from device channels.
Status LayoutAssignment::ConstrainChannelLayouts(
    HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };
  // Visit the kRecvDone instructions first. These must either impose their
  // layout, or find a matching one already existing (ConstrainChannel()
  // returns nullptr on success).
  for (HloInstruction* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kRecvDone) {
      const Layout* layout =
          get_channel_constraints(instruction)
              ->ConstrainChannel(
                  *instruction->channel_id(),
                  ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
      // A non-null return means the channel was already constrained to a
      // conflicting layout, which a kRecvDone cannot adapt to.
      TF_RET_CHECK(layout == nullptr)
          << instruction->ToString()
          << " cannot constrain layout as it was set to "
          << LayoutUtil::HumanString(*layout);
    }
  }
  // After that we go through the kSend. These are likely going to have a kCopy
  // as operand (otherwise we add it), so in case the constrained layout does
  // not match, we can change the kCopy layout (and the kSend one as well).
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() == HloOpcode::kSend) {
      HloInstruction* operand = instruction->mutable_operand(0);
      get_channel_constraints(instruction)
          ->ConstrainChannel(*instruction->channel_id(),
                             operand->shape().layout());
    } else if (instruction->IsCrossModuleAllReduce()) {
      get_channel_constraints(instruction)
          ->ConstrainChannel(instruction->channel_id().value(),
                             instruction->shape().layout());
    }
  }
  return Status::OK();
}
2010
PropagateMemorySpace(HloModule * module)2011 Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
2012 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
2013 for (const auto& buffer : alias_analysis->buffers()) {
2014 // First go through values to collect the memory spaces.
2015 int64 buffer_memory_space = Layout::kDefaultMemorySpace;
2016 for (auto value : buffer.values()) {
2017 const Shape& defining_shape = value->defining_position().shape();
2018 int64 memory_space = defining_shape.layout().memory_space();
2019 if (memory_space != Layout::kDefaultMemorySpace) {
2020 if (buffer_memory_space != Layout::kDefaultMemorySpace &&
2021 memory_space != buffer_memory_space) {
2022 return InternalError(
2023 "Buffer %d (%s) has conflicting memory spaces: %d and %d.",
2024 buffer.id(), value->ToShortString(), buffer_memory_space,
2025 memory_space);
2026 }
2027 buffer_memory_space = memory_space;
2028 }
2029 }
2030
2031 // If we encounter a memory space other than the default, then propagate all
2032 // the positions with the buffer's memory space.
2033 if (buffer_memory_space != Layout::kDefaultMemorySpace) {
2034 for (auto value : buffer.values()) {
2035 for (auto& position : value->positions()) {
2036 Shape* shape = ShapeUtil::GetMutableSubshape(
2037 position.instruction->mutable_shape(), position.index);
2038 shape->mutable_layout()->set_memory_space(buffer_memory_space);
2039 }
2040 }
2041 }
2042 }
2043 return Status::OK();
2044 }
2045
// Reconciles 'computation_layout' with the layouts that were actually
// assigned to 'computation' by the pass: fills in any parameter/result
// layouts that were left unset, and verifies that layouts which WERE set
// match what was assigned.
Status LayoutAssignment::PropagateComputationLayouts(
    HloComputation* computation, ComputationLayout* computation_layout) {
  ComputationLayout computed_computation_layout(
      computation->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
    ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
    bool needs_assign = false;
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        param_layout->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          // Only leaf (array) subshapes carry layouts.
          if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
            return Status::OK();
          }
          if (!subshape.has_layout()) {
            // No layout was requested for this leaf; take the computed one.
            needs_assign = true;
            return Status::OK();
          }
          // A layout was requested; it must agree with what was assigned.
          const auto& computed_subshape = ShapeUtil::GetSubshape(
              computed_computation_layout.parameter_shape(i), shape_index);
          if (subshape.layout() != computed_subshape.layout()) {
            return InternalError(
                "Assigned parameter shape %s does not match layout of "
                "computation shape: %s",
                computed_computation_layout.ToString(),
                computation_layout->ToString());
          }
          return Status::OK();
        }));
    if (needs_assign) {
      VLOG(4) << "Assigning layout to parameter " << i << " of computation "
              << computation->name() << ": "
              << computed_computation_layout.parameter_layout(i).ToString();
      *param_layout = computed_computation_layout.parameter_layout(i);
    }
  }
  ShapeLayout* result_layout = computation_layout->mutable_result_layout();
  if (!result_layout->LayoutIsSet()) {
    VLOG(4) << "Assigning result layout of computation " << computation->name()
            << ": " << computed_computation_layout.result_layout().ToString();
    *result_layout = computed_computation_layout.result_layout();
  } else {
    // A result layout was requested; check it matches the assignment, modulo
    // tiling and dynamic dimensions.
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            computed_computation_layout.result_layout().shape(),
            result_layout->shape()));
  }
  return Status::OK();
}
2095
// Pass entry point. Resets all layouts in 'module' and reassigns them from
// the entry computation layout plus mandatory/backend constraints. Always
// returns true on success since layouts are cleared and rebuilt.
StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
  VLOG(2) << "Running layout assignment on module " << module->name();
  TF_RETURN_IF_ERROR(Init());
  call_graph_ = CallGraph::Build(module);
  auto computations = module->computations();

  // Add copy to the operand of Send instructions, since we cannot call
  // SetOperandLayout on Send instructions as it aliases its input to the
  // output.
  //
  // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
  // operand buffers that aliases with the output.
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kSend) {
        TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
      }
    }
  }

  // Clone Conditional computations with multiple callsites, so that each
  // conditional branch can be assigned layouts independently.
  for (HloComputation* computation : computations) {
    CallGraphNode& node = call_graph_->GetNode(computation);
    if (node.caller_callsites().size() == 1) {
      continue;
    }
    if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
          return caller.instruction()->opcode() == HloOpcode::kConditional;
        })) {
      continue;
    }
    // Leave the last callsite pointing at the original computation; clone for
    // all the others.
    for (int64 i = 0; i < node.caller_callsites().size() - 1; ++i) {
      HloInstruction* caller = node.caller_callsites()[i].instruction();
      if (caller->opcode() == HloOpcode::kConditional) {
        for (int64 k = 0; k < caller->branch_count(); ++k) {
          if (computation == caller->branch_computation(k)) {
            caller->set_branch_computation(
                k, module->AddEmbeddedComputation(computation->Clone()));
            break;
          }
        }
      }
    }
  }

  // Verify computation layout is sane.
  const HloComputation* entry = module->entry_computation();
  TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
               entry->num_parameters());
  for (int64 i = 0; i < entry->num_parameters(); ++i) {
    TF_RET_CHECK(
        ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
                              entry->parameter_instruction(i)->shape()));
  }
  TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
                                     entry->root_instruction()->shape()));

  // We do two passes. The first one we pass a nullptr ComputationLayout to
  // the RunOnComputation() calls (for non entry computations), and we register
  // the ComputationLayout which are naturally flowing in DFS fashion to the
  // parameters and root instruction.
  // Walking in DFS mode though, means that we can end up with incorrect layouts
  // when seen from an outer instruction, which has across-computation
  // constraints to impose.
  // For example, the kWhile instruction needs to enforce the same layouts for
  // the parameters and root of the body, as well as the condition parameters.
  // Similarly, the kConditional instruction needs to enforce the same layouts
  // for the root of the true and false computations.
  // So in the first pass, while allowing the layouts to flow to parameters and
  // root, we also fix up the eventually inconsistent ComputationLayout, which
  // will be then made mandatory by the second pass.
  for (int64 i = 0; i < 2; ++i) {
    VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
    TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
    TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                        TuplePointsToAnalysis::Run(module));
    points_to_analysis_ = std::move(points_to_analysis);
    for (auto* computation : module->MakeComputationPostOrder()) {
      // Fusion computations get their layouts from SetFusionLayouts.
      if (computation->IsFusionComputation()) {
        continue;
      }
      if (computation == module->entry_computation()) {
        TF_RETURN_IF_ERROR(RunOnComputation(entry_computation_layout_,
                                            module->entry_computation(),
                                            channel_layout_constraints_));
      } else {
        ComputationLayout* computation_layout =
            (i == 0 || conditional_mismatch_.count(computation) > 0)
                ? nullptr
                : &FindOrDie(computation_layouts_, computation);
        TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation,
                                            channel_layout_constraints_));
      }
    }
  }
  TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
                                                 entry_computation_layout_));

  TF_RETURN_IF_ERROR(PropagateMemorySpace(module));

  TF_RETURN_IF_ERROR(CheckLayouts(module));

  // All layouts are reset then reassigned by this pass.
  return true;
}
2202
2203 /* static */
InstructionCanChangeLayout(const HloInstruction * instruction)2204 bool LayoutAssignment::InstructionCanChangeLayout(
2205 const HloInstruction* instruction) {
2206 switch (instruction->opcode()) {
2207 case HloOpcode::kAbs:
2208 case HloOpcode::kAdd:
2209 case HloOpcode::kAddDependency:
2210 case HloOpcode::kAnd:
2211 case HloOpcode::kAtan2:
2212 case HloOpcode::kBitcastConvert:
2213 case HloOpcode::kCeil:
2214 case HloOpcode::kClamp:
2215 case HloOpcode::kClz:
2216 case HloOpcode::kCompare:
2217 case HloOpcode::kComplex:
2218 case HloOpcode::kConcatenate:
2219 case HloOpcode::kConditional:
2220 case HloOpcode::kConvert:
2221 case HloOpcode::kCos:
2222 case HloOpcode::kAllGather:
2223 case HloOpcode::kAllToAll:
2224 case HloOpcode::kCollectivePermute:
2225 case HloOpcode::kDivide:
2226 case HloOpcode::kDynamicSlice:
2227 case HloOpcode::kDynamicUpdateSlice:
2228 case HloOpcode::kExp:
2229 case HloOpcode::kExpm1:
2230 case HloOpcode::kFft:
2231 case HloOpcode::kFloor:
2232 case HloOpcode::kImag:
2233 case HloOpcode::kIsFinite:
2234 case HloOpcode::kLog:
2235 case HloOpcode::kLog1p:
2236 case HloOpcode::kLogistic:
2237 case HloOpcode::kMap:
2238 case HloOpcode::kMaximum:
2239 case HloOpcode::kMinimum:
2240 case HloOpcode::kMultiply:
2241 case HloOpcode::kNegate:
2242 case HloOpcode::kNot:
2243 case HloOpcode::kOr:
2244 case HloOpcode::kXor:
2245 case HloOpcode::kPad:
2246 case HloOpcode::kPower:
2247 case HloOpcode::kReal:
2248 case HloOpcode::kReducePrecision:
2249 case HloOpcode::kReduceWindow:
2250 case HloOpcode::kRemainder:
2251 case HloOpcode::kReverse:
2252 case HloOpcode::kRoundNearestAfz:
2253 case HloOpcode::kRsqrt:
2254 case HloOpcode::kScatter:
2255 case HloOpcode::kSelect:
2256 case HloOpcode::kSelectAndScatter:
2257 case HloOpcode::kShiftLeft:
2258 case HloOpcode::kShiftRightArithmetic:
2259 case HloOpcode::kShiftRightLogical:
2260 case HloOpcode::kSign:
2261 case HloOpcode::kSin:
2262 case HloOpcode::kSlice:
2263 case HloOpcode::kSort:
2264 case HloOpcode::kSqrt:
2265 case HloOpcode::kCbrt:
2266 case HloOpcode::kSubtract:
2267 case HloOpcode::kTanh:
2268 case HloOpcode::kPopulationCount:
2269 case HloOpcode::kTriangularSolve:
2270 case HloOpcode::kCholesky:
2271 case HloOpcode::kTupleSelect:
2272 case HloOpcode::kWhile:
2273 case HloOpcode::kSetDimensionSize:
2274 // AllReduce is variadic so it needs to be careful to assign the same layout
2275 // to the corresponding input argument and Tuple index.
2276 case HloOpcode::kAllReduce:
2277 return false;
2278 case HloOpcode::kBatchNormGrad:
2279 case HloOpcode::kBatchNormInference:
2280 case HloOpcode::kBatchNormTraining:
2281 case HloOpcode::kBitcast:
2282 case HloOpcode::kBroadcast:
2283 case HloOpcode::kCall:
2284 case HloOpcode::kCollectivePermuteStart:
2285 case HloOpcode::kCollectivePermuteDone:
2286 case HloOpcode::kConstant:
2287 case HloOpcode::kConvolution:
2288 case HloOpcode::kCopy:
2289 case HloOpcode::kCopyStart:
2290 case HloOpcode::kCopyDone:
2291 case HloOpcode::kCustomCall:
2292 case HloOpcode::kDomain:
2293 case HloOpcode::kDot:
2294 case HloOpcode::kFusion:
2295 case HloOpcode::kGather:
2296 case HloOpcode::kGetTupleElement:
2297 case HloOpcode::kInfeed:
2298 case HloOpcode::kIota:
2299 case HloOpcode::kOutfeed:
2300 case HloOpcode::kParameter:
2301 case HloOpcode::kPartitionId:
2302 case HloOpcode::kRecv:
2303 case HloOpcode::kRecvDone:
2304 case HloOpcode::kReduce:
2305 case HloOpcode::kReplicaId:
2306 case HloOpcode::kReshape:
2307 case HloOpcode::kDynamicReshape:
2308 case HloOpcode::kRng:
2309 case HloOpcode::kRngBitGenerator:
2310 case HloOpcode::kRngGetAndUpdateState:
2311 case HloOpcode::kSend:
2312 case HloOpcode::kSendDone:
2313 case HloOpcode::kAfterAll:
2314 case HloOpcode::kTrace:
2315 case HloOpcode::kTranspose:
2316 case HloOpcode::kTuple:
2317 case HloOpcode::kGetDimensionSize:
2318 return true;
2319 }
2320 }
2321
2322 /* static */
IsAtMostRank1(const Shape & shape)2323 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2324 if (shape.IsArray()) {
2325 return shape.rank() <= 1;
2326 }
2327 return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2328 return IsAtMostRank1(subshape);
2329 });
2330 }
2331
Init()2332 Status LayoutAssignment::Init() {
2333 computation_layouts_.clear();
2334 conditional_mismatch_.clear();
2335 *entry_computation_layout_ = saved_entry_computation_layout_;
2336 return Status::OK();
2337 }
2338
ClearPreviousPassSideEffects(HloModule * module)2339 Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
2340 VLOG(5) << "Clearing previous side effects";
2341 // Clear all the copies which have been added, and all the related
2342 // instructions (like GTE and tuples).
2343 int64 removed_copies = 0;
2344 for (HloComputation* computation : module->computations()) {
2345 for (HloInstruction* instruction :
2346 computation->MakeInstructionPostOrder()) {
2347 if (instruction->opcode() == HloOpcode::kCopy &&
2348 added_copies_.contains(instruction)) {
2349 VLOG(5) << "Removing added copy: " << instruction->ToString();
2350 TF_RETURN_IF_ERROR(
2351 instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2352 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2353 ++removed_copies;
2354 }
2355 }
2356 }
2357 added_copies_.clear();
2358 unconstrained_layout_instructions_.clear();
2359 if (removed_copies > 0) {
2360 TupleSimplifier tuple_simplifier;
2361 HloDCE dce;
2362 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2363 TF_RETURN_IF_ERROR(dce.Run(module).status());
2364 call_graph_ = CallGraph::Build(module);
2365 }
2366 return Status::OK();
2367 }
2368
AddCopyForOperand(HloInstruction * instruction,int64 operand_number)2369 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2370 int64 operand_number) {
2371 HloInstruction* operand = instruction->mutable_operand(operand_number);
2372 if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2373 HloInstruction* copy =
2374 instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2375 operand->shape(), HloOpcode::kCopy, operand));
2376 SetupCopiedInstruction(*operand, copy, {});
2377 LayoutUtil::ClearLayout(copy->mutable_shape());
2378 TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2379 }
2380 return Status::OK();
2381 }
2382
2383 } // namespace xla
2384