1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_format.h"
32 #include "absl/strings/str_join.h"
33 #include "absl/types/span.h"
34 #include "tensorflow/compiler/xla/layout_util.h"
35 #include "tensorflow/compiler/xla/map_util.h"
36 #include "tensorflow/compiler/xla/service/computation_layout.h"
37 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
38 #include "tensorflow/compiler/xla/service/hlo_computation.h"
39 #include "tensorflow/compiler/xla/service/hlo_dce.h"
40 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
41 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
42 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
43 #include "tensorflow/compiler/xla/service/logical_buffer.h"
44 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
45 #include "tensorflow/compiler/xla/shape_layout.h"
46 #include "tensorflow/compiler/xla/shape_util.h"
47 #include "tensorflow/compiler/xla/status_macros.h"
48 #include "tensorflow/compiler/xla/statusor.h"
49 #include "tensorflow/compiler/xla/types.h"
50 #include "tensorflow/compiler/xla/util.h"
51 #include "tensorflow/compiler/xla/xla_data.pb.h"
52 #include "tensorflow/core/lib/core/errors.h"
53 #include "tensorflow/core/lib/core/status.h"
54 #include "tensorflow/core/platform/logging.h"
55 #include "tensorflow/core/platform/protobuf.h"
56
57 namespace xla {
58
operator <<(std::ostream & out,const LayoutConstraint & constraint)59 std::ostream& operator<<(std::ostream& out,
60 const LayoutConstraint& constraint) {
61 out << constraint.ToString();
62 return out;
63 }
64
// Constructs a layout constraint on a single logical buffer. The given
// layout must be valid for the buffer's shape; this is CHECK-enforced.
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
71
ToString() const72 string BufferLayoutConstraint::ToString() const {
73 return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
74 LayoutUtil::HumanString(layout_));
75 }
76
// Constructs a layout constraint for operand `operand_no` of `instruction`.
// The constrained shape layout must be set and must be shape-compatible with
// the actual operand; both invariants are CHECK-enforced.
OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64 operand_no, bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}
91
ToString() const92 string OperandLayoutConstraint::ToString() const {
93 return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
94 instruction_->name(), operand_no_,
95 shape_layout_.ToString());
96 }
97
ToString() const98 string ResultLayoutConstraint::ToString() const {
99 return absl::StrFormat("ResultLayoutConstraint: %s",
100 shape_layout_.ToString());
101 }
102
// Builds the constraint set for one computation: seeds the set of as-yet
// unconstrained buffer ids with every array-shaped logical buffer defined in
// this computation.
LayoutConstraints::LayoutConstraints(
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation)
    : points_to_analysis_(points_to_analysis), computation_(computation) {
  // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
  for (HloInstruction* inst : computation_->instructions()) {
    points_to_analysis_.GetPointsToSet(inst).ForEachElement(
        [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
          for (const LogicalBuffer* buffer : buffers) {
            // The points to analysis is computed per module, restrict
            // constraints to array buffers in this computation.
            if (buffer->IsArray() &&
                buffer->instruction()->parent() == computation) {
              unconstrained_buffer_ids_.insert(buffer->id());
            }
          }
        });
  }
}
122
GetBufferSet(const HloInstruction * instruction) const123 PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
124 const HloInstruction* instruction) const {
125 auto it = buffer_sets_cache_.find(instruction);
126 if (it != buffer_sets_cache_.end()) {
127 return it->second.get();
128 }
129 auto& buffer_set =
130 buffer_sets_cache_
131 .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
132 .first->second;
133 const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
134 points_to_set.ForEachElement(
135 [&buffer_set](const ShapeIndex& /*index*/,
136 const PointsToSet::BufferList& buffers) {
137 buffer_set->insert(buffers.begin(), buffers.end());
138 });
139 return buffer_set.get();
140 }
141
OperandBufferForwarded(const HloInstruction * instruction,int64 operand_no) const142 bool LayoutConstraints::OperandBufferForwarded(
143 const HloInstruction* instruction, int64 operand_no) const {
144 // The operand is potentially forwarded if the intersection of points-to sets
145 // of the operand and the instruction is non-empty.
146 PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
147 PointsToSet::BufferSet* operand_buffers =
148 GetBufferSet(instruction->operand(operand_no));
149 return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
150 return operand_buffers->count(b) > 0;
151 });
152 }
153
// Adds a layout constraint on the given logical buffer. Fails if the buffer
// is not array-shaped, the layout is invalid for the buffer's shape, or the
// buffer already carries a conflicting mandatory constraint. A conflicting
// non-mandatory constraint is overwritten.
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
                                          const LogicalBuffer& buffer,
                                          bool mandatory, bool dfs) {
  VLOG(3) << "SetBufferLayout : " << buffer << " : "
          << LayoutUtil::HumanString(layout);

  TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
  // Only array-shaped buffers can carry a layout constraint.
  if (!buffer.IsArray()) {
    return FailedPrecondition(
        "Layout of buffer %s cannot be constrained because buffer is not "
        "array-shaped, has shape: %s",
        buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
  }
  TF_RETURN_IF_ERROR(
      LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));

  auto iter = buffer_constraints_.find(&buffer);
  if (iter != buffer_constraints_.end()) {
    const BufferLayoutConstraint& curr_constraint = iter->second;
    if (LayoutUtil::Equal(curr_constraint.layout(), layout)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_constraint.mandatory()) {
      return FailedPrecondition(
          "Buffer %s already has the layout constraint %s, cannot add "
          "incompatible constraint %s",
          buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
          LayoutUtil::HumanString(layout));
    }
    // Existing constraint is non-mandatory: replace it with the new one.
    iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
  } else {
    // First constraint on this buffer: it must still be in the unconstrained
    // set built by the constructor.
    TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
        << buffer.ToString();
    iter = buffer_constraints_
               .insert(std::make_pair(
                   &buffer,
                   BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
               .first;
  }
  // Record the (new or updated) constraint for later propagation.
  added_constraints_.push_back(&iter->second);
  return Status::OK();
}
197
// Adds a layout constraint on operand `operand_no` of `instruction`. Fails
// if a conflicting mandatory constraint already exists, or if the operand's
// buffers may be forwarded to the instruction's output (in which case the
// constraint would propagate beyond this single use).
Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
                                           const HloInstruction* instruction,
                                           int64 operand_no, bool mandatory,
                                           bool dfs) {
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  const OperandLayoutConstraint* curr_shape_layout =
      GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_shape_layout->mandatory()) {
      return FailedPrecondition(
          "Operand %d of instruction %s already has a layout constraint "
          "%s, cannot add incompatible constraint %s",
          operand_no, instruction->name(),
          curr_shape_layout->shape_layout().ToString(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout));
    }
  }

  // If any buffers in the operand occur in the output of the instruction, then
  // return an error. This case is not handled because such a constraint changes
  // layouts beyond this immediate use and is complicated to handle.
  if (OperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constraint layout of operand %d of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name());
  }

  // Insert a new constraint, or overwrite the existing non-mandatory one.
  auto key = std::make_pair(instruction, operand_no);
  auto iter = operand_constraints_.find(key);
  if (iter == operand_constraints_.end()) {
    auto pair = std::make_pair(
        key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
                                     instruction, operand_no, mandatory, dfs));
    iter = operand_constraints_.insert(pair).first;
  } else {
    iter->second =
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs);
  }
  // Record the constraint so the driver can propagate it.
  added_constraints_.push_back(&iter->second);

  return Status::OK();
}
250
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)251 Status LayoutConstraints::SetArrayOperandLayout(
252 const Layout& layout, const HloInstruction* instruction, int64 operand_no,
253 bool mandatory, bool dfs) {
254 const HloInstruction* operand = instruction->operand(operand_no);
255 TF_RET_CHECK(operand->shape().IsArray());
256 Shape shape(operand->shape());
257 *shape.mutable_layout() = layout;
258 TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
259 return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
260 }
261
SetResultLayout(const Shape & shape_with_layout,bool dfs)262 Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
263 bool dfs) {
264 VLOG(3) << "SetResultLayout : "
265 << ShapeUtil::HumanStringWithLayout(shape_with_layout);
266
267 const ShapeLayout* curr_shape_layout = ResultLayout();
268 if (curr_shape_layout != nullptr) {
269 if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
270 return FailedPrecondition(
271 "Result of computation %s already has the layout constraint %s, "
272 "cannot add incompatible constraint %s",
273 computation_->name(), curr_shape_layout->ToString(),
274 ShapeUtil::HumanStringWithLayout(shape_with_layout));
275 }
276 // New constraint matches existing constraint. Nothing to do.
277 return Status::OK();
278 }
279
280 result_constraint_.reset(
281 new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
282 added_constraints_.push_back(result_constraint_.get());
283
284 return Status::OK();
285 }
286
// Constrains the layout of every array-shaped buffer defined in the output of
// `instruction` to the corresponding subshape layout of `shape_with_layout`.
// Precondition: the instruction defines every buffer in its output.
// NOTE(review): `dfs` is accepted but not forwarded to SetBufferLayout below —
// confirm this is intentional.
Status LayoutConstraints::SetInstructionLayout(
    const Shape& shape_with_layout, const HloInstruction* instruction,
    bool mandatory, bool dfs) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanStringWithLayout(shape_with_layout));
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory](const Shape& subshape,
                                     const ShapeIndex& index) -> Status {
        // The precondition for this method is that the instruction defines all
        // buffers in its output.
        auto buffers =
            points_to_analysis_.GetPointsToSet(instruction).element(index);
        CHECK_EQ(1, buffers.size());
        CHECK_EQ(buffers[0]->instruction(), instruction);

        if (subshape.IsArray()) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
        } else {
          // Non-array subshapes (e.g. tuple levels) carry no layout.
          return Status::OK();
        }
      });
}
320
BufferLayout(const LogicalBuffer & buffer) const321 const Layout* LayoutConstraints::BufferLayout(
322 const LogicalBuffer& buffer) const {
323 if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
324 return &constraint->layout();
325 }
326 return nullptr;
327 }
328
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const329 const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
330 const LogicalBuffer& buffer) const {
331 auto it = buffer_constraints_.find(&buffer);
332 return it == buffer_constraints_.end() ? nullptr : &it->second;
333 }
334
OperandLayout(const HloInstruction * instruction,int64 operand_no) const335 const ShapeLayout* LayoutConstraints::OperandLayout(
336 const HloInstruction* instruction, int64 operand_no) const {
337 if (const auto* constraint =
338 GetOperandLayoutConstraint(instruction, operand_no)) {
339 return &constraint->shape_layout();
340 }
341 return nullptr;
342 }
343
GetOperandLayoutConstraint(const HloInstruction * instruction,int64 operand_no) const344 const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
345 const HloInstruction* instruction, int64 operand_no) const {
346 auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
347 return it == operand_constraints_.end() ? nullptr : &it->second;
348 }
349
ResultLayout() const350 const ShapeLayout* LayoutConstraints::ResultLayout() const {
351 return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
352 }
353
ToString() const354 string LayoutConstraints::ToString() const {
355 string output;
356 absl::StrAppend(&output, "LayoutConstraints for computation ",
357 computation_->name(), ":\n");
358 for (auto* instruction : computation_->MakeInstructionPostOrder()) {
359 absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
360 for (int64 i = 0; i < instruction->operand_count(); ++i) {
361 if (OperandLayout(instruction, i) != nullptr) {
362 absl::StrAppend(&output, " operand (", i,
363 "): ", OperandLayout(instruction, i)->ToString(), "\n");
364 }
365 }
366 for (const LogicalBuffer* buffer :
367 points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
368 if (BufferLayout(*buffer) != nullptr) {
369 absl::StrAppend(&output, " ", buffer->ToString(), " : ",
370 LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
371 }
372 }
373 }
374
375 if (ResultLayout() != nullptr) {
376 absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n");
377 }
378 return output;
379 }
380
381 namespace {
382
IsHostSendRecv(const HloInstruction * instruction)383 bool IsHostSendRecv(const HloInstruction* instruction) {
384 const HloSendRecvInstruction* send_recv_instr =
385 DynCast<HloSendRecvInstruction>(instruction);
386 return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
387 }
388
389 } // namespace
390
// Records the layouts carried by host-transfer Send/Recv instructions into
// host_channel_constraints_. Fails if a channel is constrained twice with
// different layouts.
Status LayoutAssignment::BuildHostChannelConstraints(
    HloComputation* computation) {
  for (auto* instruction : computation->instructions()) {
    const HloSendRecvInstruction* send_recv_instr =
        DynCast<HloSendRecvInstruction>(instruction);
    if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
      continue;
    }

    // For host transfers the Send and Recv instruction carry the layout.
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      // The transferred data is tuple element 0 of the Send/Recv shape; it
      // must be an array that already has a layout.
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
      TF_RET_CHECK(data_shape.IsArray());
      TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
      const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
          send_recv_instr->channel_id(), data_shape.layout());
      // A non-null return indicates the channel was already constrained to a
      // different layout, which is an error for host transfers.
      TF_RET_CHECK(prev_layout == nullptr)
          << "Cannot constrain host transfer layout as it was set to "
          << LayoutUtil::HumanString(*prev_layout) << ": "
          << send_recv_instr->ToString();
    }
  }
  return Status::OK();
}
417
418 namespace {
419
IsLayoutConstrainedCustomCall(HloInstruction * instruction)420 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
421 const HloCustomCallInstruction* custom_call =
422 DynCast<HloCustomCallInstruction>(instruction);
423 return custom_call != nullptr && custom_call->layout_constrained();
424 }
425
426 } // namespace
427
// Seeds `constraints` with all layout constraints that are not negotiable:
// pre-set layouts on infeed/outfeed/parameters/custom-calls, channel layouts
// for send/recv and cross-module all-reduce, and the layouts of already-
// assigned called computations (call/while/conditional). May also rewrite
// entries of computation_layouts_ for while/conditional bodies so that their
// parameter/result layouts agree with each other.
Status LayoutAssignment::AddMandatoryConstraints(
    const ComputationLayout* computation_layout,
    ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
    LayoutConstraints* constraints) {
  VLOG(3) << "Adding mandatory layout constraints to computation "
          << computation->name();

  // Host transfers use the pass-owned host channel constraints; everything
  // else uses the caller-provided (possibly null) channel constraints.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };

  // Constrain layouts of instructions which define values with pre-existing
  // layouts.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kInfeed) {
      // Infeed layouts must match the layout of the original inserted
      // instruction.
      // TODO(b/31425034): Change infeeds to be more like parameters, with
      // shapes in the ComputationLayout.
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kOutfeed) {
      // Constrain the input to the Outfeed instruction to be the expected
      // layout of the Outfeed.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          instruction->outfeed_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kParameter) {
      if (computation_layout != nullptr) {
        const ShapeLayout& parameter_layout =
            computation_layout->parameter_layout(
                instruction->parameter_number());
        if (parameter_layout.LayoutIsSet()) {
          // Parameter layouts must match the respective layout in
          // ComputationLayout, if there is one.
          TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
              parameter_layout.shape(), instruction));
        }
      }
    } else if (IsLayoutConstrainedCustomCall(instruction)) {
      // Layout-constrained custom calls fix both their result layout and the
      // layout of every operand.
      const HloCustomCallInstruction* custom_call =
          DynCast<HloCustomCallInstruction>(instruction);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(custom_call->shape(), custom_call));
      for (int64 i = 0; i < custom_call->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            custom_call->operand_shapes_with_layout()[i], custom_call, i));
      }
    } else if (instruction->opcode() == HloOpcode::kSend ||
               instruction->opcode() == HloOpcode::kRecv) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = instruction->channel_id();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      if (instruction->opcode() == HloOpcode::kSend) {
        // TODO(b/68493863): Change to use SetOperandLayout().
        const Shape send_buffer_shape = instruction->operand(0)->shape();
        TF_RET_CHECK(send_buffer_shape.IsArray());
        Shape new_buffer_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(send_buffer_shape,
                                        instruction->channel_id());
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            new_buffer_shape, instruction->operand(0)));
      } else {
        // Recv: the received data is tuple element 0 of the Recv shape;
        // constrain that buffer to the channel's layout.
        const Shape recv_buffer_shape =
            ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
        TF_RET_CHECK(recv_buffer_shape.IsArray());
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                 {0}));
        Shape new_shape = get_channel_constraints(instruction)
                              ->LayoutShapeForChannel(
                                  recv_buffer_shape, instruction->channel_id());
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    } else if (instruction->IsCrossModuleAllReduce()) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 all_reduce_id = instruction->all_reduce_id().value();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(all_reduce_id)) {
        continue;
      }
      // TODO(b/68493863): Change to use SetOperandLayout().
      const Shape& buffer_shape = instruction->operand(0)->shape();
      TF_RET_CHECK(buffer_shape.IsArray());
      Shape new_buffer_shape =
          get_channel_constraints(instruction)
              ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(new_buffer_shape, instruction));
    }
  }

  // Constrain layouts of instructions which call computations which have
  // already been assigned layouts. Instructions which call computations in a
  // parallel element-wise context (eg, map or reduce) do not need layout
  // constraints because they operate on scalars.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kCall) {
      // kCall instruction operands and output must match the ComputationLayout
      // of the called computation.
      const ComputationLayout& called_computation_layout =
          FindOrDie(computation_layouts_, instruction->to_apply());
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          called_computation_layout.result_layout().shape(), instruction));
      TF_RET_CHECK(instruction->operand_count() ==
                   called_computation_layout.parameter_count());
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            called_computation_layout.parameter_layout(i).shape(), instruction,
            i));
      }
    } else if (instruction->opcode() == HloOpcode::kWhile) {
      // Layout of input and output of kWhile instruction must be equal and must
      // match both input and output of body computation. Also, the input of
      // condition computation must match kWhile layout.
      HloComputation* body = instruction->while_body();
      HloComputation* condition = instruction->while_condition();
      const HloInstruction* init = instruction->operand(0);
      ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
      ComputationLayout& condition_layout =
          FindOrDie(computation_layouts_, condition);

      // Check a few invariants irrespective of layout.
      CHECK_EQ(1, instruction->operand_count());
      CHECK_EQ(1, body->num_parameters());
      CHECK_EQ(1, condition->num_parameters());
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   body_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   condition_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));

      // The loop state feeds back into the body parameter, so force the body
      // parameter layout to agree with the body result layout.
      if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
                << " while=" << instruction->name()
                << " shape=" << body_layout.result_layout().ToString();
        *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
      }
      // The condition sees the same loop state; make its parameter layout
      // match the body parameter layout as well.
      if (condition_layout.parameter_layout(0) !=
          body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while condition parameter layout: cond="
                << condition->name() << " while=" << instruction->name()
                << " shape=" << body_layout.parameter_layout(0).ToString();
        *condition_layout.mutable_parameter_layout(0) =
            body_layout.parameter_layout(0);
      }

      // Constrain the output and the operand of the while instruction to match
      // the computations.
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          body_layout.result_shape(), instruction));
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          body_layout.result_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kConditional) {
      // Find the conditional branch with the most instructions and force all
      // other computations to match that layout. A potentially better decison
      // could count the number FLOPs or how constrained the layouts are.
      int64 largest_branch = 0;
      int64 largest_instruction_count =
          instruction->branch_computation(0)->instruction_count();
      for (int j = 1; j < instruction->branch_count(); ++j) {
        const int64 instruction_count =
            instruction->branch_computation(j)->instruction_count();
        if (instruction_count > largest_instruction_count) {
          largest_branch = j;
          largest_instruction_count = instruction_count;
        }
      }
      ComputationLayout& best_branch_computation_layout =
          FindOrDie(computation_layouts_,
                    instruction->branch_computation(largest_branch));
      for (int k = 0; k < instruction->branch_count(); ++k) {
        // Visit the best branch first.
        int j = (k + largest_branch) % instruction->branch_count();
        TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
        ComputationLayout& branch_computation_layout =
            FindOrDie(computation_layouts_, instruction->branch_computation(j));

        // Operand j+1 is the argument fed to branch computation j.
        DCHECK(ShapeUtil::Compatible(
            instruction->operand(j + 1)->shape(),
            branch_computation_layout.parameter_shape(0)));
        if (best_branch_computation_layout.result_layout() !=
            branch_computation_layout.result_layout()) {
          // We assign layouts in DFS fashion, so the largest_branch and current
          // branch computations might have negotiated a different layout. But
          // for the case instruction POV the layout must match, so we run again
          // on the branch j computation, this time with proper computation
          // layout.
          VLOG(2) << "Reset %conditional branch " << j
                  << " computation result layout: branch_computation="
                  << instruction->branch_computation(j)->name()
                  << " case=" << instruction->name() << " shape="
                  << best_branch_computation_layout.result_layout().ToString();
          *branch_computation_layout.mutable_result_layout() =
              best_branch_computation_layout.result_layout();
        }
        if (k == 0) {
          TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
              best_branch_computation_layout.result_shape(), instruction));
        }
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            branch_computation_layout.parameter_shape(0), instruction, j + 1,
            /*mandatory=*/true));
      }
    }
  }
  // Finally set the result layout to match ComputationLayout, if there is one.
  if (computation_layout != nullptr) {
    const ShapeLayout& result_layout = computation_layout->result_layout();
    if (result_layout.LayoutIsSet()) {
      TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
    }
  }
  return Status::OK();
}
651
652 namespace {
653
654 // The operands of a call must match the layouts of parameters in the
655 // ComputationLayout, and the call instruction itself must match the result
656 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)657 Status CheckCallLayout(HloInstruction* call,
658 const ComputationLayout& computation_layout) {
659 HloComputation* computation = call->to_apply();
660 TF_RET_CHECK(computation->num_parameters() == call->operand_count());
661 for (int64 i = 0; i < computation->num_parameters(); ++i) {
662 TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
663 call->operand(i)->shape()));
664 }
665 TF_RET_CHECK(
666 computation_layout.result_layout().MatchesLayoutInShape(call->shape()));
667 return Status::OK();
668 }
669
670 // Operands of layout-constrained custom calls must match the expected
671 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)672 Status CheckCustomCallLayout(HloInstruction* instruction) {
673 if (IsLayoutConstrainedCustomCall(instruction)) {
674 const HloCustomCallInstruction* custom_call =
675 DynCast<HloCustomCallInstruction>(instruction);
676 for (int64 i = 0; i < custom_call->operand_count(); ++i) {
677 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
678 custom_call->operand(i)->shape(),
679 custom_call->operand_shapes_with_layout()[i]));
680 }
681 }
682 return Status::OK();
683 }
684
685 // For a while instruction, all the following layouts must be the same:
686 // (1) init operand
687 // (2) condition computation parameter
688 // (3) body computation parameter
689 // (4) body computation result
690 // (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)691 Status CheckWhileLayout(HloInstruction* while_inst,
692 const ComputationLayout& condition_computation_layout,
693 const ComputationLayout& body_computation_layout) {
694 auto init_shape = while_inst->operand(0)->shape();
695 TF_RET_CHECK(
696 condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
697 init_shape));
698 TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
699 init_shape));
700 TF_RET_CHECK(
701 body_computation_layout.result_layout().MatchesLayoutInShape(init_shape));
702 TF_RET_CHECK(
703 LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
704 return Status::OK();
705 }
706
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)707 Status CheckConditionalLayout(
708 HloInstruction* instruction,
709 absl::Span<const ComputationLayout> branch_computation_layouts) {
710 for (int j = 0; j < instruction->branch_count(); ++j) {
711 const HloInstruction* branch_operand = instruction->operand(j + 1);
712 TF_RET_CHECK(branch_computation_layouts[0].result_layout() ==
713 branch_computation_layouts[j].result_layout());
714 TF_RET_CHECK(
715 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
716 instruction->shape()));
717 TF_RET_CHECK(
718 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
719 instruction->branch_computation(j)->root_instruction()->shape()));
720 TF_RET_CHECK(
721 branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
722 branch_operand->shape()));
723 }
724 return Status::OK();
725 }
726
727 // Fusion parameters must match the layout of the fusion instructions operands,
728 // and the root of the fusion expression must match the layout of the fusion
729 // instruction.
CheckFusionLayout(HloInstruction * fusion)730 Status CheckFusionLayout(HloInstruction* fusion) {
731 TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
732
733 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
734 fusion->shape(), fusion->fused_expression_root()->shape()));
735 for (int64 i = 0; i < fusion->operand_count(); ++i) {
736 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
737 fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape()));
738 }
739 return Status::OK();
740 }
741
742 // The layout of a parameter must match the respective layout in the
743 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)744 Status CheckParameterLayout(HloInstruction* parameter,
745 const ComputationLayout& computation_layout) {
746 const ShapeLayout& parameter_layout =
747 computation_layout.parameter_layout(parameter->parameter_number());
748 if (parameter_layout.LayoutIsSet() &&
749 !parameter_layout.MatchesLayoutInShape(parameter->shape())) {
750 return InternalError(
751 "parameter instruction %s does not match layout of computation "
752 "shape: %s",
753 parameter->ToString(), parameter_layout.ToString());
754 }
755 return Status::OK();
756 }
757
758 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)759 Status CheckConstantLayout(HloInstruction* constant) {
760 if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(),
761 constant->shape())) {
762 return InternalError(
763 "constant instruction %s does not match the layout of its literal %s",
764 constant->ToString(),
765 ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
766 }
767 return Status::OK();
768 }
769
770 } // namespace
771
// Returns a copy of `instruction` (added to its computation) whose shape is
// `shape_with_layout`. Array-shaped values get a single kCopy; tuple-shaped
// values are deep-copied element-by-element via GetTupleElement + recursion,
// then re-assembled with a new Tuple instruction.
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      // Extract the element so it can either be reused directly or copied
      // (recursively) into the requested layout.
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (ShapeUtil::Equal(target_shape, instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // The new Tuple was created without the requested layout; clear it and
    // copy every subshape layout over from shape_with_layout.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    // NOTE(review): only this array-level kCopy is registered as an added
    // copy; the tuple path above registers nothing itself (each leaf copy is
    // registered in the recursive call) -- confirm intentional.
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
827
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64 operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  // Both the constraint and the operand must carry fully specified layouts
  // before they can be meaningfully compared.
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
    VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }
  VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // Materialize a copy in the required layout and rewire only this use;
  // other users of `operand` keep the original value.
  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
854
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)855 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
856 HloInstruction* copy,
857 const ShapeIndex& index) {
858 if (instruction.has_sharding()) {
859 // If the index is empty, we want to copy the whole sharding, in case the
860 // sharding is a tuple sharding.
861 HloSharding sharding =
862 !index.empty() && instruction.sharding().IsTuple()
863 ? instruction.sharding().GetSubSharding(instruction.shape(), index)
864 : instruction.sharding();
865 // We propagate the sharding to the copied instruction only if it is a
866 // special sharding, like tiled ones.
867 // Otherwise it is preferable to leave the new instruction without device,
868 // and let the automatic device placer to choose the best location.
869 auto device = sharding.UniqueDevice();
870 if (!device || HloSharding::IsReservedDevice(*device)) {
871 copy->set_sharding(sharding);
872 }
873 }
874 copy->set_metadata(instruction.metadata());
875 }
876
// Post-pass validation: verifies that every instruction in every non-fusion
// computation has a valid layout, that layouts agree with the points-to
// analysis of their source buffers, and that instructions with special layout
// contracts (call/custom-call/fusion/parameter/constant/while/conditional)
// satisfy them. Also checks the entry result layout when one is set.
Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            // Only leaf (array) positions carry layouts worth comparing.
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              for (const LogicalBuffer* buffer : buffers) {
                if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional: {
          // Gather each branch computation's layout so they can be compared
          // against one another and against the conditional itself.
          std::vector<ComputationLayout> branch_computation_layouts;
          for (auto branch_computation : instruction->branch_computations()) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation));
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        ShapeUtil::Equal(module->result_shape(), result_layout.shape()));
  }
  return Status::OK();
}
966
// Constructs the pass. `entry_computation_layout` is both read and (a saved
// copy of it) retained so it can be restored across reruns;
// `instruction_can_change_layout_func` is the backend predicate for whether
// an instruction may have different operand/result layouts;
// `channel_constraints` is optional and, when present, is snapshotted so side
// effects can be rolled back.
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    std::function<bool(const HloInstruction*)>
        instruction_can_change_layout_func,
    ChannelLayoutConstraints* channel_constraints)
    : entry_computation_layout_(entry_computation_layout),

      saved_entry_computation_layout_(*entry_computation_layout),
      channel_layout_constraints_(channel_constraints),
      instruction_can_change_layout_func_(
          std::move(instruction_can_change_layout_func)) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
986
// Given the layout chosen for `instruction`'s output, proposes a layout for
// the operand at `operand_no`, or returns nullptr when no preference can be
// derived. Both the instruction and the operand must be array-shaped.
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
    const Layout& output_layout, const HloInstruction* instruction,
    int64 operand_no) {
  const HloInstruction* operand = instruction->operand(operand_no);
  CHECK(instruction->shape().IsArray());
  CHECK(operand->shape().IsArray());
  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == instruction->shape().rank() &&
      !instruction_can_change_layout_func_(instruction)) {
    // Propagate the result layout to the operand layout if the instruction
    // requires the same layout for the result and the operand.
    //
    // For elementwise operations, using the same layout for the operands and
    // the result also has the following benefits:
    // 1) the elementwise operation can reuse its operand's buffer, and
    // 2) the input and output elements can reuse the same linear index.
    return absl::make_unique<Layout>(output_layout);
  }

  if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the operand's layout to the output.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(instruction->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    const Shape& output_shape = instruction->shape();
    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
        LayoutUtil::MinorToMajor(output_layout));
    // Start from the default layout on the operand and ask AlignLayouts for
    // an operand layout that makes the reshape a bitcast, if one exists.
    Shape operand_shape = operand->shape();
    *operand_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(operand_shape);
    auto aligned_operand_shape =
        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
    if (aligned_operand_shape) {
      auto operand_layout = aligned_operand_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
      return absl::make_unique<Layout>(operand_layout);
    }
  }

  if (instruction->opcode() == HloOpcode::kTranspose) {
    // Pick the operand layout that makes the transpose a bitcast.
    int64 rank = instruction->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    for (int64 i = 0; i < rank; ++i) {
      // Map each output minor-to-major position back through the transpose's
      // dimension permutation to the corresponding operand dimension.
      int64 output_dim = LayoutUtil::Minor(output_layout, i);
      int64 operand_dim = instruction->dimensions(output_dim);
      new_minor_to_major[i] = operand_dim;
    }
    Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(
        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
    return absl::make_unique<Layout>(operand_layout);
  }

  // No preference for this instruction/operand combination.
  return nullptr;
}
1052
// Mirror of ChooseOperandLayoutFromOutputLayout: given the layout chosen for
// `user`'s operand at `operand_no`, proposes a layout for `user`'s output, or
// returns nullptr when no preference can be derived.
std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
    const Layout& operand_layout, const HloInstruction* user,
    int64 operand_no) {
  const HloInstruction* operand = user->operand(operand_no);

  CHECK(user->shape().IsArray() && operand->shape().IsArray());

  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == user->shape().rank() &&
      !instruction_can_change_layout_func_(user)) {
    // Assign users the same layout as the operand.
    return absl::make_unique<Layout>(operand_layout);
  }

  if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the output's layout to the operand.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(user->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        operand->shape().element_type(),
        AsInt64Slice(operand->shape().dimensions()),
        LayoutUtil::MinorToMajor(operand_layout));
    // Start from the default output layout and ask AlignLayouts for an
    // output layout that makes the reshape a bitcast, if one exists.
    Shape output_shape = user->shape();
    *output_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(output_shape);
    auto aligned_user_shape =
        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
    if (aligned_user_shape) {
      auto user_layout = aligned_user_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
      return absl::make_unique<Layout>(user_layout);
    }
  }

  if (user->opcode() == HloOpcode::kTranspose) {
    // Pick the user layout that makes the transpose a bitcast.
    int64 rank = user->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    // Invert the permutation so operand dimensions map forward to the
    // output dimensions they come from.
    auto inverse_dimensions = InversePermutation(user->dimensions());
    for (int64 i = 0; i < rank; ++i) {
      int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
      int64 user_dim = inverse_dimensions[operand_dim];
      new_minor_to_major[i] = user_dim;
    }
    Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
    return absl::make_unique<Layout>(user_layout);
  }

  // No preference for this user/operand combination.
  return nullptr;
}
1113
// Drains the constraint worklist, dispatching each constraint to the matching
// Propagate* handler until a fixed point (empty worklist) is reached.
Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
  // Gathers all initial constraints in a worklist and propagates them in
  // depth-first order. DFS order seems to be better than BFS because a
  // constraint is propagated as far as possible before propagating unrelated
  // constraints which makes it less likely that conflicting constraints will be
  // propagated to instructions. However, we should experiment with other orders
  // too.
  std::deque<const LayoutConstraint*> worklist;

  // Lambda for moving newly added constraints to the worklist.
  auto add_new_constraints_to_worklist = [constraints, &worklist]() {
    // Add constraints to the front of the deque for DFS ordering.
    for (auto* constraint : constraints->ConsumeAddedConstraints()) {
      if (constraint->dfs()) {
        worklist.push_front(constraint);
      } else {
        // Non-DFS constraints go to the back, i.e. are handled breadth-first.
        worklist.push_back(constraint);
      }
    }
  };
  add_new_constraints_to_worklist();

  while (!worklist.empty()) {
    const LayoutConstraint* layout_constraint = worklist.front();
    worklist.pop_front();
    VLOG(2) << "Propagating " << layout_constraint->ToString()
            << " to its neighbors.";
    // Dispatch on the constraint's dynamic type; the three concrete subtypes
    // (buffer / operand / result) each have their own propagation rule.
    if (auto* buffer_constraint =
            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateBufferConstraint(*buffer_constraint, constraints));
    } else if (auto* operand_constraint =
                   dynamic_cast<const OperandLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateOperandConstraint(*operand_constraint, constraints));
    } else if (auto* result_constraint =
                   dynamic_cast<const ResultLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateResultConstraint(*result_constraint, constraints));
    } else {
      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
    }

    // Propagation may have produced new constraints; enqueue them.
    add_new_constraints_to_worklist();
  }
  return Status::OK();
}
1163
1164 namespace {
1165
1166 // Returns a vector containing all array-shaped uses (instruction and operand
1167 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const LogicalBuffer & buffer,const TuplePointsToAnalysis & points_to_analysis)1168 std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
1169 const LogicalBuffer& buffer,
1170 const TuplePointsToAnalysis& points_to_analysis) {
1171 CHECK(buffer.IsArray());
1172 std::vector<std::pair<const HloInstruction*, int64>> uses;
1173 for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
1174 if (!buffer_alias.instruction()->shape().IsArray()) {
1175 continue;
1176 }
1177 // This alias must be the top-level (index == {}) of the instruction's
1178 // result because the instruction produces an array.
1179 CHECK(buffer_alias.index().empty());
1180
1181 // Add all uses of the instruction's output.
1182 for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1183 for (int64 operand_no :
1184 user->OperandIndices(buffer_alias.instruction())) {
1185 uses.emplace_back(user, operand_no);
1186 }
1187 }
1188 }
1189 return uses;
1190 }
1191
1192 } // namespace
1193
// Pushes a layout constraint on a use of `instruction` backwards onto the
// logical buffers that can be the source of that value, setting a mandatory
// buffer layout wherever one hasn't already been chosen.
Status LayoutAssignment::PropagateUseConstraintToDefs(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    LayoutConstraints* constraints) {
  // Try to set all logical buffers which may be sources of the given operand to
  // match the given layout.
  const PointsToSet& points_to_set =
      constraints->points_to_analysis().GetPointsToSet(instruction);
  return points_to_set.ForEachElementWithStatus(
      [&shape_layout, constraints](
          const ShapeIndex& index,
          const PointsToSet::BufferList& buffers) -> Status {
        // Only leaf (array) positions in the constrained shape carry layouts.
        if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
          for (const LogicalBuffer* buffer : buffers) {
            // Don't overwrite a buffer layout that has already been set.
            if (constraints->BufferLayout(*buffer) == nullptr &&
                buffer->shape().IsArray()) {
              TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                  ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
                  *buffer, /*mandatory=*/true));
            }
          }
        }
        return Status::OK();
      });
}
1218
1219 namespace {
1220 // A transpose or a reshape that only changes trivial dimensions have meaningful
1221 // layouts that are valuable to propagate in a depthfirst manner to avoid
1222 // unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo)1223 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
1224 switch (hlo.opcode()) {
1225 case HloOpcode::kReshape:
1226 return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
1227 case HloOpcode::kTranspose:
1228 return true;
1229 default:
1230 return false;
1231 }
1232 }
1233
1234 } // namespace
1235
// Propagates a constraint on (user, operand_no) in three directions:
// backwards to the buffers defining the operand, sideways to the user's other
// operands (for non-layout-changing instructions), and forwards to the
// buffers the user defines.
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    // Tuple-shaped operands carry no single array layout to propagate.
    return Status::OK();
  }
  if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
    return Status::OK();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->OperandBufferForwarded(user,
                                          operand_constraint.operand_no())) {
    return Status::OK();
  }

  // Rank 0/1 shapes have only one possible layout; nothing to propagate.
  int64 operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    return Status::OK();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!instruction_can_change_layout_func_(user)) {
    // Make sure all siblings have the same layout as the operand.
    for (int64 operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      const int64 sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        //   operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        //   we continue to propagate our layout to the result, and then
        //   backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false));
    }
    // Also push the operand layout forward onto the user's output buffers of
    // matching rank.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return Status::OK();
          }
          if (subshape.rank() <= 1) {
            return Status::OK();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return Status::OK();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(
                  user, shape_index));
          // Make sure the output has the same layout as the operand.
          const BufferLayoutConstraint* constraint =
              constraints->GetBufferLayoutConstraint(*buffer);
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          if (constraint == nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                operand_constraint.shape_layout().layout(), *buffer,
                /*mandatory=*/false));
          }
          return Status::OK();
        }));
    return Status::OK();
  }
  // Layout-changing user: ask the cost heuristic for a preferred output
  // layout derived from this operand layout, and apply it where the buffer
  // layout is still unset or non-mandatory.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return Status::OK();
        }
        if (subshape.rank() <= 1) {
          return Status::OK();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(user,
                                                                 shape_index));
        if (constraints->BufferLayout(*buffer) == nullptr ||
            !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}
1377
// Pushes a buffer's layout constraint backwards onto the operands of the
// instruction that defines the buffer. Mandatory for non-layout-changing
// instructions; heuristic (and optional) otherwise.
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  // Rank <= 1 shapes have a unique layout; nothing to propagate.
  if (IsAtMostRank1(instruction->shape())) {
    return Status::OK();
  }

  for (int64 operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!instruction_can_change_layout_func_(instruction)) {
      // Copy the layout to the operand.
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true));
      }
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no constraint already.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        // Ask the cost heuristic for an operand layout derived from the
        // output layout; nullptr means no preference.
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction)));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return Status::OK();
}
1432
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1433 Status LayoutAssignment::PropagateBufferConstraint(
1434 const BufferLayoutConstraint& buffer_constraint,
1435 LayoutConstraints* constraints) {
1436 // Only propagate array layouts.
1437 const LogicalBuffer& buffer = buffer_constraint.buffer();
1438 if (!buffer.IsArray()) {
1439 return Status::OK();
1440 }
1441 TF_RETURN_IF_ERROR(
1442 PropagateBufferConstraintToUses(buffer_constraint, constraints));
1443 return PropagateBufferConstraintToOperands(buffer_constraint, constraints);
1444 }
1445
// Pushes a buffer's layout constraint forward onto every array-shaped use of
// the buffer (or its aliases), as a non-mandatory operand constraint.
Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
    const HloInstruction* user = user_operand_no.first;
    int64 operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !constraints->OperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
    }
  }

  return Status::OK();
}
1469
PropagateResultConstraint(const ResultLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1470 Status LayoutAssignment::PropagateResultConstraint(
1471 const ResultLayoutConstraint& layout_constraint,
1472 LayoutConstraints* constraints) {
1473 // Propagate the use constraint of the root instruction up to the logical
1474 // buffers which make up the result.
1475 return PropagateUseConstraintToDefs(
1476 layout_constraint.shape_layout(),
1477 constraints->computation()->root_instruction(), constraints);
1478 }
1479
1480 namespace {
1481
1482 // Infers the layout of the array at the given index in the given instruction's
1483 // output using points-to analysis. Precondition: The given instruction must
1484 // not produce this array value (that is, the array is forwarded from the
1485 // instruction's operands).
InferArrayLayout(const TuplePointsToAnalysis & points_to_analysis,HloInstruction * instruction,const ShapeIndex & index)1486 StatusOr<Layout> InferArrayLayout(
1487 const TuplePointsToAnalysis& points_to_analysis,
1488 HloInstruction* instruction, const ShapeIndex& index) {
1489 // This function should only be called for array shapes which don't yet have
1490 // layouts.
1491 const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
1492 TF_RET_CHECK(subshape.IsArray());
1493 TF_RET_CHECK(!subshape.has_layout());
1494
1495 // The instruction should not define the buffer at this index.
1496 TF_RET_CHECK(
1497 !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
1498 << instruction->ToString();
1499
1500 const auto& source_buffers =
1501 points_to_analysis.GetPointsToSet(instruction).element(index);
1502 TF_RET_CHECK(!source_buffers.empty());
1503
1504 // Verify the layout is the same for every LogicalBuffer which this location
1505 // ('instruction' and 'index') points to.
1506 const Layout* first_buffer_layout = nullptr;
1507 for (const LogicalBuffer* source_buffer : source_buffers) {
1508 if (!source_buffer->shape().has_layout()) {
1509 // This should not happen because we've assigned layouts to all
1510 // instructions preceding this one.
1511 return InternalError("LogicalBuffer %s does not have a layout",
1512 source_buffer->ToString());
1513 }
1514
1515 if (first_buffer_layout == nullptr) {
1516 first_buffer_layout = &source_buffer->shape().layout();
1517 } else if (!LayoutUtil::Equal(source_buffer->shape().layout(),
1518 *first_buffer_layout)) {
1519 // The points-to set is ambiguous for this index and the different source
1520 // buffers have different layouts. This case is possible in valid XLA
1521 // computations because we do not propagate BufferLayoutConstraints to all
1522 // LogicalBuffers which may alias the constrained LogicalBuffer at some
1523 // point in the computation.
1524 return FailedPrecondition(
1525 "Array at index {%s} in instruction %s aliases buffers %s "
1526 "and %s which have different layouts",
1527 absl::StrJoin(index, ","), instruction->name(),
1528 source_buffers[0]->ToString(), source_buffer->ToString());
1529 }
1530 }
1531
1532 return *first_buffer_layout;
1533 }
1534
1535 // For fusion instructions, set the layout of each fused parameter instruction
1536 // to match the layout of its corresponding fusion instruction operand. Also,
1537 // set the layout of the fused root to match the layout of the fusion
1538 // instruction itself.
SetFusionLayouts(HloInstruction * fusion)1539 Status SetFusionLayouts(HloInstruction* fusion) {
1540 TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
1541 for (auto* fused_instruction :
1542 fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
1543 if (fused_instruction->opcode() == HloOpcode::kParameter) {
1544 const HloInstruction* fusion_operand =
1545 fusion->operand(fused_instruction->parameter_number());
1546 DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
1547 fused_instruction->shape()));
1548 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
1549 fusion_operand->shape(), fused_instruction->mutable_shape()));
1550 } else if (fused_instruction == fusion->fused_expression_root()) {
1551 // The layout of the root of the fused expression must match the fusion
1552 // instruction layout.
1553 DCHECK(
1554 ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
1555 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
1556 fusion->shape(), fused_instruction->mutable_shape()));
1557 } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
1558 // A GTE inherits its layout from its operand (which should ultimately be
1559 // a parameter).
1560 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
1561 fused_instruction->operand(0)->shape().tuple_shapes(
1562 fused_instruction->tuple_index()),
1563 fused_instruction->mutable_shape()));
1564 } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
1565 // Give constants the layout of their literal.
1566 TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
1567 fused_instruction->literal().shape(),
1568 fused_instruction->mutable_shape()));
1569 } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
1570 // Nop; leave the infeed layout alone.
1571 } else if (fusion->fusion_kind() != HloInstruction::FusionKind::kCustom) {
1572 // Other instructions don't have layouts inside of fusion nodes.
1573 // But do not clear layouts for other instructions in custom fusion nodes.
1574 LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
1575 }
1576 }
1577
1578 return Status::OK();
1579 }
1580
1581 } // namespace
1582
// Materializes the computed layout constraints onto the instructions of
// 'computation': clears stale layouts, inserts copies where an operand's
// producer layout disagrees with its use constraint, writes the constrained
// buffer layouts into instruction shapes, and infers layouts for forwarded
// (aliased) array values via points-to analysis. Returns an error if any
// layout cannot be assigned or verified.
Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
                                       HloComputation* computation) {
  VLOG(2) << "Assigning layouts to computation: " << computation->name();
  XLA_VLOG_LINES(2, computation->ToString());
  XLA_VLOG_LINES(2, constraints.ToString());

  // Post order guarantees operands are processed before their users, so an
  // operand's layout is final by the time CopyOperandIfLayoutsDiffer runs.
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      }
    }

    // Set the layouts of the array shapes this instruction defines as indicated
    // by the respective BufferLayoutConstraints. Any array shapes in the output
    // of the instruction which are not defined by the instruction (eg, array
    // elements in a Tuple instruction) will be assigned below via inference.
    for (const LogicalBuffer* buffer :
         constraints.points_to_analysis().GetBuffersDefinedByInstruction(
             instruction)) {
      if (!buffer->shape().IsArray()) {
        continue;
      }

      TF_RET_CHECK(buffer->instruction() == instruction);
      // By this point every defined array buffer must have a constraint
      // (unconstrained buffers were given defaults before this call).
      const Layout* buffer_layout = constraints.BufferLayout(*buffer);
      TF_RET_CHECK(buffer_layout != nullptr);

      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(*buffer_layout, buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = *buffer_layout;
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !subshape->IsArray()) {
            return Status::OK();
          }
          // Set Layout of subshape to match layout of LogicalBuffer which
          // produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(constraints.points_to_analysis(),
                                               instruction, index));
          return Status::OK();
        }));

    // Fusion instructions require some layouts to be set on fused instructions
    // inside the fusion instruction.
    if (instruction->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    // Execute extra verification step once the layout has been finalized.
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Shape must be valid.
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }
  return Status::OK();
}
1664
CalculateComputationLayout(HloComputation * computation)1665 Status LayoutAssignment::CalculateComputationLayout(
1666 HloComputation* computation) {
1667 ComputationLayout computation_layout(computation->ComputeProgramShape(),
1668 /*ignore_layouts=*/false);
1669 InsertOrDie(&computation_layouts_, computation, computation_layout);
1670 VLOG(2) << " Calculated ComputationLayout = "
1671 << computation_layout.ToString();
1672 return Status::OK();
1673 }
1674
ClearComputationLayouts(HloComputation * computation)1675 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
1676 // Clear existing layouts of the instructions. All layouts must be assigned
1677 // by the LayoutAssignment pass, except for those on parameters, the
1678 // computation result, and a couple special cases. The former two are
1679 // specified in computation_layout. Clearing the layouts here avoids hiding
1680 // potential bugs in the layout assignment pass that may accidentally use the
1681 // existing layout.
1682 for (HloInstruction* instruction : computation->instructions()) {
1683 if (instruction->opcode() == HloOpcode::kBitcast) {
1684 // bitcasts are inherently layout sensitive and so a bitcast instruction
1685 // present in the IR before layout assignment is a bug.
1686 return InternalError(
1687 "Unexpected bitcast operation seen during layout assignment: %s.",
1688 instruction->ToString());
1689 }
1690 // Some instructions carry mandatory layouts in their shape.
1691 if (instruction->opcode() != HloOpcode::kInfeed &&
1692 !IsLayoutConstrainedCustomCall(instruction)) {
1693 LayoutUtil::ClearLayout(instruction->mutable_shape());
1694 }
1695 }
1696 return Status::OK();
1697 }
1698
// Runs layout assignment on a single (non-fusion) computation: clears existing
// layouts, collects mandatory and backend-specific constraints, propagates
// them to a fixed point, assigns defaults to any still-unconstrained buffers,
// and finally writes the layouts into the instructions. If
// 'computation_layout' is null, the resulting layout is computed and recorded
// afterwards; 'channel_constraints' (optional) accumulates layouts for
// channel instructions across modules.
Status LayoutAssignment::RunOnComputation(
    ComputationLayout* computation_layout,
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
          << ")";

  // Must be run before clearing layouts.
  TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));

  TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
  if (computation_layout != nullptr) {
    auto it = computation_layouts_.find(computation);
    if (it == computation_layouts_.end()) {
      VLOG(2) << "  New ComputationLayout = " << computation_layout->ToString();
      computation_layouts_.emplace(computation, *computation_layout);
    } else {
      // A layout was already recorded for this computation; the caller must
      // be passing either that entry or the entry computation layout.
      TF_RET_CHECK(computation_layout == &it->second ||
                   computation_layout == entry_computation_layout_);
      VLOG(2) << "  Existing ComputationLayout = "
              << computation_layout->ToString();
    }
  } else {
    VLOG(2) << "  No ComputationLayout specified (will be calculated)";
  }

  // Construct LayoutConstraints with all layout constraints of the computation.
  LayoutConstraints constraints(points_to_analysis, computation);

  // Add constraints required for correctness on all backends (eg, entry
  // parameter layout constraints).
  TF_RETURN_IF_ERROR(AddMandatoryConstraints(
      computation_layout, channel_constraints, computation, &constraints));

  // Add any backend-specific constraints.
  TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));

  // Propagates layouts from mandatory and backend constraints.
  TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

  // Prior to applying default layouts, we take note of all HLO instructions
  // which lack a layout constraint.
  for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
    unconstrained_layout_instructions_.insert(
        points_to_analysis.GetBuffer(buffer_id).instruction());
  }

  // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
  // layout and propagate the change.
  while (!constraints.unconstrained_buffer_ids().empty()) {
    int unconstrained_count = constraints.unconstrained_buffer_ids().size();

    // Arbitrarily pick the first unconstrained buffer and give it the default
    // layout (or the literal layout, in case of constants). By construction
    // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
    const LogicalBuffer& buffer = points_to_analysis.GetBuffer(
        *constraints.unconstrained_buffer_ids().begin());
    const HloInstruction* instruction = buffer.instruction();
    Layout new_layout =
        instruction->opcode() == HloOpcode::kConstant
            ? ShapeUtil::GetSubshape(instruction->literal().shape(),
                                     buffer.index())
                  .layout()
            : LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
    // Non-mandatory so downstream propagation may still override aliases.
    TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
                                                   /*mandatory=*/false));

    TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

    // To verify progress has been made, check that the number of unconstrained
    // buffers has been reduced.
    CHECK_LT(constraints.unconstrained_buffer_ids().size(),
             unconstrained_count);
  }
  // All logical buffers should have constraints at this point. All that
  // remains is assign the constraints to the buffers and infer layouts for
  // aliased buffers.
  TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));

  // If the computation layout wasn't specified, now it is the time to compute
  // it according to the parameters and root instruction layouts.
  // This allows the first pass through this API to record the best flowing
  // layout to parameters and root instruction.
  if (computation_layout == nullptr) {
    TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
  }

  // Record the layouts assigned for any communication ops in
  // channel_constraints so that they are constrained for future modules.
  if (channel_constraints != nullptr) {
    TF_RETURN_IF_ERROR(
        ConstrainChannelLayouts(computation, channel_constraints));
  }

  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape())) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }
  return Status::OK();
}
1807
// Registers (or reconciles) the layouts chosen for channel instructions
// (kRecvDone, kSend, cross-module all-reduce) in the channel constraint
// tables. kRecvDone must either impose its layout or already agree with the
// registered one; for kSend and cross-module all-reduce, a mismatch is fixed
// by routing the operand through a kCopy with the registered layout.
Status LayoutAssignment::ConstrainChannelLayouts(
    HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  // Host send/recv channels use a separate constraint table from
  // device-to-device channels.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };
  // We go through the kRecvDone before. These must either impose their layout,
  // or find a matching one already existing (ConstrainChannel() returns
  // nullptr).
  for (HloInstruction* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kRecvDone) {
      const Layout* layout =
          get_channel_constraints(instruction)
              ->ConstrainChannel(
                  instruction->channel_id(),
                  ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
      // Non-null means a conflicting layout was already registered for this
      // channel. (The stream operands are only evaluated on failure, so
      // dereferencing 'layout' here is safe.)
      TF_RET_CHECK(layout == nullptr)
          << instruction->ToString()
          << " cannot constrain layout as it was set to "
          << LayoutUtil::HumanString(*layout);
    }
  }
  // After that we go through the kSend. These are likely going to have a kCopy
  // as operand (otherwise we add it), so in case the constrained layout does
  // not match, we can change the kCopy layout (and the kSend one as well).
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() == HloOpcode::kSend) {
      HloInstruction* operand = instruction->mutable_operand(0);
      const Layout* layout = get_channel_constraints(instruction)
                                 ->ConstrainChannel(instruction->channel_id(),
                                                    operand->shape().layout());
      if (layout != nullptr) {
        // We found an already constrained layout which does not match the one
        // the kSend wants to impose. Either add a new kCopy, or use the
        // existing one to marshal the correct shape.
        Shape shape = operand->shape();
        *shape.mutable_layout() = *layout;
        if (operand->opcode() != HloOpcode::kCopy) {
          HloInstruction* copy = operand->parent()->AddInstruction(
              HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
          RegisterAddedCopy(copy);
          SetupCopiedInstruction(*operand, copy, {});
          TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
          operand = copy;
        } else {
          *operand->mutable_shape() = shape;
        }
        // kSend's shape is a tuple; element {0} mirrors the operand shape.
        Shape* send_shape =
            ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
        *send_shape = shape;
      }
    } else if (instruction->IsCrossModuleAllReduce()) {
      const Layout* layout =
          get_channel_constraints(instruction)
              ->ConstrainChannel(instruction->all_reduce_id().value(),
                                 instruction->shape().layout());
      if (layout != nullptr) {
        // We found an already constrained layout which does not match the one
        // the channel wants to impose. Either add a new kCopy, or use the
        // existing one to marshal the correct shape.
        HloInstruction* operand = instruction->mutable_operand(0);
        Shape shape = operand->shape();
        *shape.mutable_layout() = *layout;
        if (operand->opcode() != HloOpcode::kCopy) {
          HloInstruction* copy = operand->parent()->AddInstruction(
              HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
          RegisterAddedCopy(copy);
          SetupCopiedInstruction(*operand, copy, {});
          TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
          operand = copy;
        } else {
          *operand->mutable_shape() = shape;
        }
        *instruction->mutable_shape() = shape;
      }
    }
  }
  return Status::OK();
}
1888
PropagateComputationLayouts(HloComputation * computation,ComputationLayout * computation_layout)1889 Status LayoutAssignment::PropagateComputationLayouts(
1890 HloComputation* computation, ComputationLayout* computation_layout) {
1891 ComputationLayout computed_computation_layout(
1892 computation->ComputeProgramShape(),
1893 /*ignore_layouts=*/false);
1894 for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
1895 ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
1896 if (!param_layout->LayoutIsSet()) {
1897 VLOG(4) << "Assigning layout to parameter " << i << " of computation "
1898 << computation->name() << ": "
1899 << computed_computation_layout.parameter_layout(i).ToString();
1900 *param_layout = computed_computation_layout.parameter_layout(i);
1901 } else {
1902 TF_RET_CHECK(computed_computation_layout.parameter_layout(i) ==
1903 *param_layout);
1904 }
1905 }
1906 ShapeLayout* result_layout = computation_layout->mutable_result_layout();
1907 if (!result_layout->LayoutIsSet()) {
1908 VLOG(4) << "Assigning result layout of computation " << computation->name()
1909 << ": " << computed_computation_layout.result_layout().ToString();
1910 *result_layout = computed_computation_layout.result_layout();
1911 } else {
1912 TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout);
1913 }
1914 return Status::OK();
1915 }
1916
// Pass entry point. After sanity-checking the entry computation layout, runs
// layout assignment twice over every non-fusion computation. Always returns
// true on success since all layouts are reset then reassigned by this pass.
StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
  VLOG(2) << "Running layout assignment on module " << module->name();
  TF_RETURN_IF_ERROR(Init());

  // Verify computation layout is sane.
  const HloComputation* entry = module->entry_computation();
  TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
               entry->num_parameters());
  for (int64 i = 0; i < entry->num_parameters(); ++i) {
    TF_RET_CHECK(
        ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
                              entry->parameter_instruction(i)->shape()));
  }
  TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
                                     entry->root_instruction()->shape()));

  // We do two passes. The first one we pass a nullptr ComputationLayout to
  // the RunOnComputation() calls (for non entry computations), and we register
  // the ComputationLayout which are naturally flowing in DFS fashion to the
  // parameters and root instruction.
  // Walking in DFS mode though, means that we can end up with incorrect layouts
  // when seen from an outer instruction, which has across-computation
  // constraints to impose.
  // For example, the kWhile instruction needs to enforce the same layouts for
  // the parameters and root of the body, as well as the condition parameters.
  // Similarly, the kConditional instruction needs to enforce the same layouts
  // for the root of the true and false computations.
  // So in the first pass, while allowing the layouts to flow to parameters and
  // root, we also fix up the eventually inconsistent ComputationLayout, which
  // will be then made mandatory by the second pass.
  for (int64 i = 0; i < 2; ++i) {
    VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
    // Undo copies and layouts introduced by the previous iteration so each
    // pass starts from a clean module.
    TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
    TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                        TuplePointsToAnalysis::Run(module));
    for (auto* computation : module->MakeComputationPostOrder()) {
      if (computation->IsFusionComputation()) {
        continue;
      }
      if (computation == module->entry_computation()) {
        TF_RETURN_IF_ERROR(RunOnComputation(
            entry_computation_layout_, *points_to_analysis,
            module->entry_computation(), channel_layout_constraints_));
      } else {
        // First pass: let layouts flow (nullptr). Second pass: make the
        // layouts recorded by the first pass mandatory.
        ComputationLayout* computation_layout =
            (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation);
        TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
                                            *points_to_analysis, computation,
                                            channel_layout_constraints_));
      }
    }
  }
  TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
                                                 entry_computation_layout_));
  TF_RETURN_IF_ERROR(CheckLayouts(module));

  // All layouts are reset then reassigned by this pass.
  return true;
}
1976
/* static */
// Returns whether the given instruction may produce an output whose layout
// differs from its operands' layouts. Opcodes in the first group must keep
// operand and output layouts in agreement (return false); opcodes in the
// second group may pick an output layout independently (return true).
bool LayoutAssignment::InstructionCanChangeLayout(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    // Layout-preserving ops: output layout must match operand layout(s).
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kWhile:
      return false;
    // Ops whose output layout is independent of their operands' layouts.
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kRng:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kAfterAll:
    case HloOpcode::kTrace:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
      return true;
  }
}
2080
2081 /* static */
IsAtMostRank1(const Shape & shape)2082 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2083 if (shape.IsArray()) {
2084 return shape.rank() <= 1;
2085 }
2086 return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2087 return IsAtMostRank1(subshape);
2088 });
2089 }
2090
// Resets per-run state so the pass can be executed again from scratch.
Status LayoutAssignment::Init() {
  computation_layouts_.clear();
  // Restore the entry layout from the saved snapshot, discarding any
  // modifications a previous Run() may have made to it.
  *entry_computation_layout_ = saved_entry_computation_layout_;
  return Status::OK();
}
2096
ClearPreviousPassSideEffects(HloModule * module)2097 Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
2098 VLOG(5) << "Clearing previous side effects";
2099 // Clear all the copies which have been added, and all the related
2100 // instructions (like GTE and tuples).
2101 int64 removed_copies = 0;
2102 for (HloComputation* computation : module->computations()) {
2103 for (HloInstruction* instruction :
2104 computation->MakeInstructionPostOrder()) {
2105 if (instruction->opcode() == HloOpcode::kCopy &&
2106 added_copies_.contains(instruction)) {
2107 VLOG(5) << "Removing added copy: " << instruction->ToString();
2108 TF_RETURN_IF_ERROR(
2109 instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2110 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2111 ++removed_copies;
2112 }
2113 }
2114 }
2115 added_copies_.clear();
2116 unconstrained_layout_instructions_.clear();
2117 if (removed_copies > 0) {
2118 TupleSimplifier tuple_simplifier;
2119 HloDCE dce;
2120 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2121 TF_RETURN_IF_ERROR(dce.Run(module).status());
2122 }
2123 ResetChannelConstraints();
2124 return Status::OK();
2125 }
2126
AddCopyForOperand(HloInstruction * instruction,int64 operand_number)2127 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2128 int64 operand_number) {
2129 HloInstruction* operand = instruction->mutable_operand(operand_number);
2130 if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2131 HloInstruction* copy =
2132 instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2133 operand->shape(), HloOpcode::kCopy, operand));
2134 SetupCopiedInstruction(*operand, copy, {});
2135 LayoutUtil::ClearLayout(copy->mutable_shape());
2136 TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2137 }
2138 return Status::OK();
2139 }
2140
2141 } // namespace xla
2142