1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28
29 #include "absl/algorithm/container.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/map_util.h"
37 #include "tensorflow/compiler/xla/service/call_graph.h"
38 #include "tensorflow/compiler/xla/service/computation_layout.h"
39 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
40 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
41 #include "tensorflow/compiler/xla/service/hlo_computation.h"
42 #include "tensorflow/compiler/xla/service/hlo_dce.h"
43 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
44 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
45 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
46 #include "tensorflow/compiler/xla/service/logical_buffer.h"
47 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
48 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
49 #include "tensorflow/compiler/xla/shape_layout.h"
50 #include "tensorflow/compiler/xla/shape_util.h"
51 #include "tensorflow/compiler/xla/status_macros.h"
52 #include "tensorflow/compiler/xla/statusor.h"
53 #include "tensorflow/compiler/xla/types.h"
54 #include "tensorflow/compiler/xla/util.h"
55 #include "tensorflow/compiler/xla/xla_data.pb.h"
56 #include "tensorflow/core/lib/core/errors.h"
57 #include "tensorflow/core/lib/core/status.h"
58 #include "tensorflow/core/platform/logging.h"
59 #include "tensorflow/core/platform/protobuf.h"
60
61 namespace xla {
62
operator <<(std::ostream & out,const LayoutConstraint & constraint)63 std::ostream& operator<<(std::ostream& out,
64 const LayoutConstraint& constraint) {
65 out << constraint.ToString();
66 return out;
67 }
68
// Constructs a constraint pinning `buffer` to `layout`.
// CHECK-fails if `layout` is not a valid layout for the buffer's shape.
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) {
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
75
ToString() const76 string BufferLayoutConstraint::ToString() const {
77 return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
78 LayoutUtil::HumanString(layout_));
79 }
80
// Constructs a constraint on operand `operand_no` of `instruction`.
// CHECK-fails if the ShapeLayout has no layout set, or if the constrained
// shape is not compatible with the operand's actual shape.
OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64 operand_no, bool mandatory, bool dfs)
    : LayoutConstraint(mandatory, dfs),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}
95
ToString() const96 string OperandLayoutConstraint::ToString() const {
97 return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
98 instruction_->name(), operand_no_,
99 shape_layout_.ToString());
100 }
101
ToString() const102 string ResultLayoutConstraint::ToString() const {
103 return absl::StrFormat("ResultLayoutConstraint: %s",
104 shape_layout_.ToString());
105 }
106
// Seeds the constraint set for `computation`: every array-shaped logical
// buffer defined in this computation starts out unconstrained.
LayoutConstraints::LayoutConstraints(
    const TuplePointsToAnalysis& points_to_analysis,
    HloComputation* computation)
    : points_to_analysis_(points_to_analysis), computation_(computation) {
  // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
  for (HloInstruction* inst : computation_->instructions()) {
    points_to_analysis_.GetPointsToSet(inst).ForEachElement(
        [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
          for (const LogicalBuffer* buffer : buffers) {
            // The points to analysis is computed per module, restrict
            // constraints to array buffers in this computation.
            if (buffer->IsArray() &&
                buffer->instruction()->parent() == computation) {
              unconstrained_buffer_ids_.insert(buffer->id());
            }
          }
        });
  }
}
126
GetBufferSet(const HloInstruction * instruction) const127 PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
128 const HloInstruction* instruction) const {
129 auto it = buffer_sets_cache_.find(instruction);
130 if (it != buffer_sets_cache_.end()) {
131 return it->second.get();
132 }
133 auto& buffer_set =
134 buffer_sets_cache_
135 .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
136 .first->second;
137 const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
138 points_to_set.ForEachElement(
139 [&buffer_set](const ShapeIndex& /*index*/,
140 const PointsToSet::BufferList& buffers) {
141 buffer_set->insert(buffers.begin(), buffers.end());
142 });
143 return buffer_set.get();
144 }
145
OperandBufferForwarded(const HloInstruction * instruction,int64 operand_no) const146 bool LayoutConstraints::OperandBufferForwarded(
147 const HloInstruction* instruction, int64 operand_no) const {
148 // The operand is potentially forwarded if the intersection of points-to sets
149 // of the operand and the instruction is non-empty.
150 PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
151 PointsToSet::BufferSet* operand_buffers =
152 GetBufferSet(instruction->operand(operand_no));
153 return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
154 return operand_buffers->count(b) > 0;
155 });
156 }
157
// Adds (or tightens) a layout constraint on `buffer`.
//  - Errors if the buffer is not array-shaped or `layout` is invalid for it.
//  - A new constraint that matches the existing one is a no-op.
//  - A non-mandatory constraint never overrides a mandatory one (skipped).
//  - Two conflicting mandatory constraints are a FailedPrecondition error.
// On success the resulting constraint is appended to added_constraints_ so
// it can be propagated later.
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
                                          const LogicalBuffer& buffer,
                                          bool mandatory, bool dfs) {
  VLOG(3) << "SetBufferLayout : " << buffer << " : "
          << LayoutUtil::HumanString(layout);

  TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
  if (!buffer.IsArray()) {
    return FailedPrecondition(
        "Layout of buffer %s cannot be constrained because buffer is not "
        "array-shaped, has shape: %s",
        buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
  }
  TF_RETURN_IF_ERROR(
      LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));

  auto iter = buffer_constraints_.find(&buffer);
  if (iter != buffer_constraints_.end()) {
    const BufferLayoutConstraint& curr_constraint = iter->second;
    // Only minor-to-major order is compared; other layout fields are ignored.
    if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_constraint.mandatory()) {
      if (!mandatory) {
        // A mere preference cannot displace a mandatory constraint.
        VLOG(3) << "Buffer" << buffer
                << " already has a mandatory layout constrain, skipping";
        return Status::OK();
      }
      return FailedPrecondition(
          "Buffer %s already has the layout constraint %s, cannot add "
          "incompatible constraint %s",
          buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
          LayoutUtil::HumanString(layout));
    }
    // Existing constraint was non-mandatory: replace it in place.
    iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
  } else {
    // First constraint on this buffer; it must still be in the
    // unconstrained set built by the constructor.
    TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
        << buffer.ToString();
    iter = buffer_constraints_
               .insert(std::make_pair(
                   &buffer,
                   BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
               .first;
  }
  added_constraints_.push_back(&iter->second);
  return Status::OK();
}
206
// Constrains operand `operand_no` of `instruction` to `shape_with_layout`.
//  - A matching existing constraint is a no-op.
//  - A conflicting mandatory constraint is a FailedPrecondition error.
//  - Refuses to constrain an operand whose buffers may be forwarded into the
//    instruction's output (such a constraint would propagate beyond this use).
// On success the constraint is appended to added_constraints_.
Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
                                           const HloInstruction* instruction,
                                           int64 operand_no, bool mandatory,
                                           bool dfs) {
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  const OperandLayoutConstraint* curr_shape_layout =
      GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout, /*minor_to_major_only=*/true)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_shape_layout->mandatory()) {
      return FailedPrecondition(
          "Operand %d of instruction %s already has a layout constraint "
          "%s, cannot add incompatible constraint %s",
          operand_no, instruction->name(),
          curr_shape_layout->shape_layout().ToString(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout));
    }
  }

  // If any buffers in the operand occur in the output of the instruction, then
  // return an error. This case is not handled because such a constraint changes
  // layouts beyond this immediate use and is complicated to handle.
  if (OperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constraint layout of operand %d of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name());
  }

  // Insert a fresh constraint or overwrite the existing (non-mandatory) one.
  auto key = std::make_pair(instruction, operand_no);
  auto iter = operand_constraints_.find(key);
  if (iter == operand_constraints_.end()) {
    auto pair = std::make_pair(
        key, OperandLayoutConstraint(ShapeLayout(shape_with_layout),
                                     instruction, operand_no, mandatory, dfs));
    iter = operand_constraints_.insert(pair).first;
  } else {
    iter->second =
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs);
  }
  added_constraints_.push_back(&iter->second);

  return Status::OK();
}
259
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64 operand_no,bool mandatory,bool dfs)260 Status LayoutConstraints::SetArrayOperandLayout(
261 const Layout& layout, const HloInstruction* instruction, int64 operand_no,
262 bool mandatory, bool dfs) {
263 const HloInstruction* operand = instruction->operand(operand_no);
264 TF_RET_CHECK(operand->shape().IsArray());
265 Shape shape(operand->shape());
266 *shape.mutable_layout() = layout;
267 TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
268 return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
269 }
270
SetResultLayout(const Shape & shape_with_layout,bool dfs)271 Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
272 bool dfs) {
273 VLOG(3) << "SetResultLayout : "
274 << ShapeUtil::HumanStringWithLayout(shape_with_layout);
275
276 const ShapeLayout* curr_shape_layout = ResultLayout();
277 if (curr_shape_layout != nullptr) {
278 if (!curr_shape_layout->MatchesLayoutInShape(
279 shape_with_layout, /*minor_to_major_only=*/true)) {
280 return FailedPrecondition(
281 "Result of computation %s already has the layout constraint %s, "
282 "cannot add incompatible constraint %s",
283 computation_->name(), curr_shape_layout->ToString(),
284 ShapeUtil::HumanStringWithLayout(shape_with_layout));
285 }
286 // New constraint matches existing constraint. Nothing to do.
287 return Status::OK();
288 }
289 result_constraint_.reset(
290 new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
291 added_constraints_.push_back(result_constraint_.get());
292
293 return Status::OK();
294 }
295
// Constrains every array-shaped buffer defined in `instruction`'s output to
// the corresponding layout in `shape_with_layout`. Errors if the shape is not
// compatible with the instruction's shape. Precondition (CHECKed per
// subshape): the instruction itself defines every buffer in its output, i.e.
// no output buffer is forwarded from an operand.
Status LayoutConstraints::SetInstructionLayout(
    const Shape& shape_with_layout, const HloInstruction* instruction,
    bool mandatory, bool dfs) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanStringWithLayout(shape_with_layout));
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory](const Shape& subshape,
                                     const ShapeIndex& index) -> Status {
        // The precondition for this method is that the instruction defines all
        // buffers in its output.
        auto buffers =
            points_to_analysis_.GetPointsToSet(instruction).element(index);
        CHECK_EQ(1, buffers.size());
        CHECK_EQ(buffers[0]->instruction(), instruction);

        // Only array subshapes that actually carry a layout are constrained;
        // tuple nodes and layout-less subshapes are skipped.
        if (subshape.IsArray() && subshape.has_layout()) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
        } else {
          return Status::OK();
        }
      });
}
329
BufferLayout(const LogicalBuffer & buffer) const330 const Layout* LayoutConstraints::BufferLayout(
331 const LogicalBuffer& buffer) const {
332 if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
333 return &constraint->layout();
334 }
335 return nullptr;
336 }
337
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const338 const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
339 const LogicalBuffer& buffer) const {
340 auto it = buffer_constraints_.find(&buffer);
341 return it == buffer_constraints_.end() ? nullptr : &it->second;
342 }
343
OperandLayout(const HloInstruction * instruction,int64 operand_no) const344 const ShapeLayout* LayoutConstraints::OperandLayout(
345 const HloInstruction* instruction, int64 operand_no) const {
346 if (const auto* constraint =
347 GetOperandLayoutConstraint(instruction, operand_no)) {
348 return &constraint->shape_layout();
349 }
350 return nullptr;
351 }
352
GetOperandLayoutConstraint(const HloInstruction * instruction,int64 operand_no) const353 const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
354 const HloInstruction* instruction, int64 operand_no) const {
355 auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
356 return it == operand_constraints_.end() ? nullptr : &it->second;
357 }
358
ResultLayout() const359 const ShapeLayout* LayoutConstraints::ResultLayout() const {
360 return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
361 }
362
ToString() const363 string LayoutConstraints::ToString() const {
364 string output;
365 absl::StrAppend(&output, "LayoutConstraints for computation ",
366 computation_->name(), ":\n");
367 for (auto* instruction : computation_->MakeInstructionPostOrder()) {
368 absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
369 for (int64 i = 0; i < instruction->operand_count(); ++i) {
370 if (OperandLayout(instruction, i) != nullptr) {
371 absl::StrAppend(&output, " operand (", i,
372 "): ", OperandLayout(instruction, i)->ToString(), "\n");
373 }
374 }
375 for (const LogicalBuffer* buffer :
376 points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
377 if (BufferLayout(*buffer) != nullptr) {
378 absl::StrAppend(&output, " ", buffer->ToString(), " : ",
379 LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
380 }
381 }
382 }
383
384 if (ResultLayout() != nullptr) {
385 absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n");
386 }
387 return output;
388 }
389
390 namespace {
391
IsHostSendRecv(const HloInstruction * instruction)392 bool IsHostSendRecv(const HloInstruction* instruction) {
393 const HloSendRecvInstruction* send_recv_instr =
394 DynCast<HloSendRecvInstruction>(instruction);
395 return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
396 }
397
398 } // namespace
399
// Records, for every host-transfer Send/Recv in `computation`, the layout of
// the transferred data in host_channel_constraints_, keyed by channel id.
// Errors if a channel was already constrained (each host channel may carry
// exactly one layout).
Status LayoutAssignment::BuildHostChannelConstraints(
    HloComputation* computation) {
  for (auto* instruction : computation->instructions()) {
    const HloSendRecvInstruction* send_recv_instr =
        DynCast<HloSendRecvInstruction>(instruction);
    // Only host-transfer send/recv family instructions participate.
    if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
      continue;
    }

    // For host transfers the Send and Recv instruction carry the layout.
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      // The payload is tuple element 0 of the send/recv shape; it must be an
      // array with a concrete layout.
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
      TF_RET_CHECK(data_shape.IsArray());
      TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
      // ConstrainChannel returns the previously registered layout, if any;
      // a non-null result means the channel was constrained twice.
      const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
          *send_recv_instr->channel_id(), data_shape.layout());
      TF_RET_CHECK(prev_layout == nullptr)
          << "Cannot constrain host transfer layout as it was set to "
          << LayoutUtil::HumanString(*prev_layout) << ": "
          << send_recv_instr->ToString();
    }
  }
  return Status::OK();
}
426
427 namespace {
428
IsLayoutConstrainedCustomCall(HloInstruction * instruction)429 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
430 const HloCustomCallInstruction* custom_call =
431 DynCast<HloCustomCallInstruction>(instruction);
432 return custom_call != nullptr && custom_call->layout_constrained();
433 }
434
IsLayoutConstrainedAllReduce(const HloInstruction * instruction)435 bool IsLayoutConstrainedAllReduce(const HloInstruction* instruction) {
436 const HloAllReduceInstruction* all_reduce =
437 DynCast<HloAllReduceInstruction>(instruction);
438 return all_reduce != nullptr && all_reduce->constrain_layout();
439 }
440
441 } // namespace
442
// Adds all non-negotiable layout constraints for `computation`:
//  1. instructions whose layouts are fixed a priori: infeed/outfeed,
//     parameters (when a ComputationLayout is given), layout-constrained
//     custom calls and all-reduces, and channel-constrained send/recv and
//     cross-module all-reduce;
//  2. instructions that must agree with the already-assigned
//     ComputationLayouts of computations they call: kCall, kWhile,
//     kConditional;
//  3. finally, the computation's result layout when one is known.
Status LayoutAssignment::AddMandatoryConstraints(
    const ComputationLayout* computation_layout,
    ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
    LayoutConstraints* constraints) {
  VLOG(3) << "Adding mandatory layout constraints to computation "
          << computation->name();

  // Host transfers use the dedicated per-module host channel table; all other
  // channel instructions use the caller-provided channel_constraints.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };

  // Constrain layouts of instructions which define values with pre-existing
  // layouts.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kInfeed) {
      // Infeed layouts must match the layout of the original inserted
      // instruction.
      // TODO(b/31425034): Change infeeds to be more like parameters, with
      // shapes in the ComputationLayout.
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kOutfeed) {
      // Constrain the input to the Outfeed instruction to be the expected
      // layout of the Outfeed.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          instruction->outfeed_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kParameter) {
      if (computation_layout != nullptr) {
        const ShapeLayout& parameter_layout =
            computation_layout->parameter_layout(
                instruction->parameter_number());
        // Parameter layouts must match the respective layout in
        // ComputationLayout, if there is one.
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            parameter_layout.shape(), instruction));
      }
    } else if (IsLayoutConstrainedCustomCall(instruction)) {
      // Layout-constrained custom calls fix both their result layout and the
      // layout of every operand.
      const HloCustomCallInstruction* custom_call =
          DynCast<HloCustomCallInstruction>(instruction);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(custom_call->shape(), custom_call));
      for (int64 i = 0; i < custom_call->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            custom_call->operand_shapes_with_layout()[i], custom_call, i));
      }
    } else if (instruction->opcode() == HloOpcode::kSend ||
               instruction->opcode() == HloOpcode::kRecv) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = *instruction->channel_id();
      // Unconstrained channels impose nothing here.
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      if (instruction->opcode() == HloOpcode::kSend) {
        // TODO(b/68493863): Change to use SetOperandLayout().
        const Shape send_buffer_shape = instruction->operand(0)->shape();
        TF_RET_CHECK(send_buffer_shape.IsArray());
        Shape new_buffer_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(send_buffer_shape,
                                        *instruction->channel_id());
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            new_buffer_shape, instruction->operand(0)));
      } else {
        // Recv: constrain the logical buffer of the received payload (tuple
        // element {0}) directly.
        const Shape recv_buffer_shape =
            ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
        TF_RET_CHECK(recv_buffer_shape.IsArray());
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                 {0}));
        Shape new_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(recv_buffer_shape,
                                        *instruction->channel_id());
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    } else if (IsLayoutConstrainedAllReduce(instruction)) {
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->IsCrossModuleAllReduce()) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64 channel_id = instruction->channel_id().value();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      // TODO(b/68493863): Change to use SetOperandLayout().
      const Shape& buffer_shape = instruction->operand(0)->shape();
      TF_RET_CHECK(buffer_shape.IsArray());
      Shape new_buffer_shape =
          get_channel_constraints(instruction)
              ->LayoutShapeForChannel(buffer_shape, channel_id);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(new_buffer_shape, instruction));
    }
  }

  // Constrain layouts of instructions which call computations which have
  // already been assigned layouts. Instructions which call computations in a
  // parallel element-wise context (eg, map or reduce) do not need layout
  // constraints because they operate on scalars.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kCall) {
      // kCall instruction operands and output must match the ComputationLayout
      // of the called computation.
      const ComputationLayout& called_computation_layout =
          FindOrDie(computation_layouts_, instruction->to_apply());
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          called_computation_layout.result_layout().shape(), instruction));
      TF_RET_CHECK(instruction->operand_count() ==
                   called_computation_layout.parameter_count());
      for (int64 i = 0; i < instruction->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            called_computation_layout.parameter_layout(i).shape(), instruction,
            i));
      }
    } else if (instruction->opcode() == HloOpcode::kWhile) {
      // Layout of input and output of kWhile instruction must be equal and must
      // match both input and output of body computation. Also, the input of
      // condition computation must match kWhile layout.
      HloComputation* body = instruction->while_body();
      HloComputation* condition = instruction->while_condition();
      const HloInstruction* init = instruction->operand(0);
      ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
      ComputationLayout& condition_layout =
          FindOrDie(computation_layouts_, condition);

      // Check a few invariants irrespective of layout.
      CHECK_EQ(1, instruction->operand_count());
      CHECK_EQ(1, body->num_parameters());
      CHECK_EQ(1, condition->num_parameters());
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   body_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   condition_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));

      // Force the body parameter layout to equal the body result layout
      // (the loop-carried value must round-trip unchanged).
      if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
                << " while=" << instruction->name()
                << " shape=" << body_layout.result_layout().ToString();
        *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
      }
      // The condition sees the same loop-carried value, so its parameter must
      // match the body parameter as well.
      if (condition_layout.parameter_layout(0) !=
          body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while condition parameter layout: cond="
                << condition->name() << " while=" << instruction->name()
                << " shape=" << body_layout.parameter_layout(0).ToString();
        *condition_layout.mutable_parameter_layout(0) =
            body_layout.parameter_layout(0);
      }

      // Constrain the output and the operand of the while instruction to match
      // the computations.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          body_layout.result_shape(), instruction, 0));
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          body_layout.result_shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kConditional) {
      // Find the conditional branch with the most instructions and force all
      // other computations to match that layout. A potentially better decision
      // could count the number FLOPs or how constrained the layouts are.
      int64 largest_branch = 0;
      int64 largest_instruction_count =
          instruction->branch_computation(0)->instruction_count();
      for (int j = 1; j < instruction->branch_count(); ++j) {
        const int64 instruction_count =
            instruction->branch_computation(j)->instruction_count();
        if (instruction_count > largest_instruction_count) {
          largest_branch = j;
          largest_instruction_count = instruction_count;
        }
      }
      ComputationLayout& best_branch_computation_layout =
          FindOrDie(computation_layouts_,
                    instruction->branch_computation(largest_branch));
      for (int k = 0; k < instruction->branch_count(); ++k) {
        // Visit the best branch first.
        int j = (k + largest_branch) % instruction->branch_count();
        TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
        // NOTE(review): the parameter-count check above indexes the rotated
        // branch `j`, while the layout handling below indexes `k`; confirm
        // this mixed indexing is intentional.
        ComputationLayout& branch_computation_layout =
            FindOrDie(computation_layouts_, instruction->branch_computation(k));
        if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
                best_branch_computation_layout.result_layout().shape(),
                /*minor_to_major_only=*/true)) {
          // This branch disagrees with the chosen result layout: drop its
          // layout and remember the layout it must be re-assigned with.
          computation_layouts_.erase(instruction->branch_computation(k));
          InsertOrDie(&conditional_mismatch_,
                      instruction->branch_computation(k),
                      best_branch_computation_layout);
        } else {
          TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
              branch_computation_layout.parameter_shape(0), instruction, k + 1,
              /*mandatory=*/true));
        }
      }
      // Operand 0 is the branch selector; branch operands start at index 1.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          best_branch_computation_layout.parameter_shape(0), instruction,
          largest_branch + 1,
          /*mandatory=*/true));
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          best_branch_computation_layout.result_shape(), instruction));
    }
  }
  // Finally set the result layout to match ComputationLayout, if there is one.
  if (conditional_mismatch_.count(computation) > 0) {
    TF_RETURN_IF_ERROR(constraints->SetResultLayout(
        FindOrDie(conditional_mismatch_, computation).result_layout().shape()));
  } else if (computation_layout != nullptr) {
    const ShapeLayout& result_layout = computation_layout->result_layout();
    if (result_layout.LayoutIsSet()) {
      TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
    }
  }
  return Status::OK();
}
663
664 namespace {
665
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)666 bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
667 return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
668 }
669
670 // The operands of a call must match the layouts of parameters in the
671 // ComputationLayout, and the call instruction itself must match the result
672 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)673 Status CheckCallLayout(HloInstruction* call,
674 const ComputationLayout& computation_layout) {
675 HloComputation* computation = call->to_apply();
676 TF_RET_CHECK(computation->num_parameters() == call->operand_count());
677 for (int64 i = 0; i < computation->num_parameters(); ++i) {
678 TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
679 call->operand(i)->shape(), /*minor_to_major_only=*/true));
680 }
681 TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
682 call->shape(), /*minor_to_major_only=*/true));
683 return Status::OK();
684 }
685
686 // Operands of layout-constrained custom calls must match the expected
687 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)688 Status CheckCustomCallLayout(HloInstruction* instruction) {
689 if (IsLayoutConstrainedCustomCall(instruction)) {
690 const HloCustomCallInstruction* custom_call =
691 DynCast<HloCustomCallInstruction>(instruction);
692 for (int64 i = 0; i < custom_call->operand_count(); ++i) {
693 TF_RET_CHECK(
694 LayoutsInShapesEqual(custom_call->operand(i)->shape(),
695 custom_call->operand_shapes_with_layout()[i]));
696 }
697 }
698 return Status::OK();
699 }
700
701 // For a while instruction, all the following layouts must be the same:
702 // (1) init operand
703 // (2) condition computation parameter
704 // (3) body computation parameter
705 // (4) body computation result
706 // (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)707 Status CheckWhileLayout(HloInstruction* while_inst,
708 const ComputationLayout& condition_computation_layout,
709 const ComputationLayout& body_computation_layout) {
710 auto init_shape = while_inst->operand(0)->shape();
711 TF_RET_CHECK(
712 condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
713 init_shape, /*minor_to_major_only=*/true));
714 TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
715 init_shape, /*minor_to_major_only=*/true));
716 TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
717 init_shape, /*minor_to_major_only=*/true));
718 TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
719 return Status::OK();
720 }
721
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)722 Status CheckConditionalLayout(
723 HloInstruction* instruction,
724 absl::Span<const ComputationLayout> branch_computation_layouts) {
725 for (int j = 0; j < instruction->branch_count(); ++j) {
726 const HloInstruction* branch_operand = instruction->operand(j + 1);
727 TF_RET_CHECK(
728 branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
729 branch_computation_layouts[j].result_layout().shape(),
730 /*minor_to_major_only=*/true));
731 TF_RET_CHECK(
732 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
733 instruction->shape(), /*minor_to_major_only=*/true));
734 TF_RET_CHECK(
735 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
736 instruction->branch_computation(j)->root_instruction()->shape(),
737 /*minor_to_major_only=*/true));
738 TF_RET_CHECK(
739 branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
740 branch_operand->shape(), /*minor_to_major_only=*/true));
741 }
742 return Status::OK();
743 }
744
745 // Fusion parameters must match the layout of the fusion instructions operands,
746 // and the root of the fusion expression must match the layout of the fusion
747 // instruction.
CheckFusionLayout(HloInstruction * fusion)748 Status CheckFusionLayout(HloInstruction* fusion) {
749 TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
750
751 TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
752 fusion->fused_expression_root()->shape()));
753 for (int64 i = 0; i < fusion->operand_count(); ++i) {
754 TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
755 fusion->operand(i)->shape()));
756 }
757 return Status::OK();
758 }
759
760 // The layout of a parameter must match the respective layout in the
761 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)762 Status CheckParameterLayout(HloInstruction* parameter,
763 const ComputationLayout& computation_layout) {
764 const ShapeLayout& parameter_layout =
765 computation_layout.parameter_layout(parameter->parameter_number());
766 return ShapeUtil::ForEachSubshapeWithStatus(
767 parameter_layout.shape(),
768 [&](const Shape& subshape, const ShapeIndex& shape_index) {
769 if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
770 !subshape.has_layout()) {
771 return Status::OK();
772 }
773 if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
774 subshape,
775 ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
776 return InternalError(
777 "parameter instruction %s does not match layout of computation "
778 "shape: %s",
779 parameter->ToString(), parameter_layout.ToString());
780 }
781 return Status::OK();
782 });
783 }
784
785 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)786 Status CheckConstantLayout(HloInstruction* constant) {
787 if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
788 return InternalError(
789 "constant instruction %s does not match the layout of its literal %s",
790 constant->ToString(),
791 ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
792 }
793 return Status::OK();
794 }
795
796 } // namespace
797
// Returns a copy of `instruction` whose shape carries the layout of
// `shape_with_layout`. Arrays get a single kCopy; tuples are copied
// element-wise via get-tuple-element, recursing only into elements whose
// layout actually differs, and regathered into a fresh Tuple instruction.
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  // The target must be fully layouted and element-wise compatible with the
  // instruction being copied.
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      // Extract element i; the GTE inherits the element's current shape.
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
                                                    instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // Wipe whatever layouts the new Tuple picked up from its operands and
    // stamp on the requested layout wholesale.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    // Record the copy so the pass can later identify copies it added itself.
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    // Opaque/token shapes cannot be copied.
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
854
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64 operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  // Both the constraint and the operand must have concrete layouts before a
  // comparison is meaningful.
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
                                                operand->shape())) {
    VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }
  VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // If the operand is only used by a conditional, do the copy inside the branch
  // to avoid overhead for other branches.
  if (instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
      instruction->operand(operand_no)->user_count() == 1) {
    // Operand 0 of a conditional is the branch selector; operand_no-1 is the
    // branch this operand feeds.
    auto branch_comp = instruction->branch_computation(operand_no - 1);
    auto param = branch_comp->parameter_instruction(0);
    // Rewrite the branch parameter to accept the operand's current layout,
    // then insert the layout-converting copy inside the branch.
    *param->mutable_shape() = operand->shape();
    // Snapshot users before the rewrite: CreateCopyWithNewLayout adds the
    // copy itself as a new user of `param`.
    auto param_users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
                        CreateCopyWithNewLayout(operand_layout.shape(), param));
    for (auto user : param_users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
    }
    VLOG(4) << "New copy of " << operand->ToString() << " is "
            << param_copy->ToString();
    if (param == branch_comp->root_instruction()) {
      // The parameter was also the branch result; re-root on the copy.
      branch_comp->set_root_instruction(param_copy,
                                        /*accept_different_shape=*/true);
    }
    // Record the branch computation's new parameter layout so later
    // verification sees the updated contract.
    *FindOrDie(computation_layouts_, branch_comp).mutable_parameter_layout(0) =
        ShapeLayout(operand->shape());
    return Status::OK();
  }

  // General case: insert the copy next to the user and splice it in.
  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
906
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)907 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
908 HloInstruction* copy,
909 const ShapeIndex& index) {
910 if (instruction.has_sharding()) {
911 // If the index is empty, we want to copy the whole sharding, in case the
912 // sharding is a tuple sharding.
913 HloSharding sharding =
914 !index.empty() && instruction.sharding().IsTuple()
915 ? instruction.sharding().GetSubSharding(instruction.shape(), index)
916 : instruction.sharding();
917 // We propagate the sharding to the copied instruction only if it is a
918 // special sharding, like tiled ones.
919 // Otherwise it is preferable to leave the new instruction without device,
920 // and let the automatic device placer to choose the best location.
921 auto device = sharding.UniqueDevice();
922 if (!device || HloSharding::IsReservedDevice(*device)) {
923 copy->set_sharding(sharding);
924 }
925 }
926 copy->set_metadata(instruction.metadata());
927 }
928
// Verifies the final layout state of `module`: every instruction in every
// non-fusion computation has a valid layout, each output leaf agrees with the
// layouts of the logical buffers that may feed it, opcode-specific constraints
// (call, custom-call, fusion, parameter, constant, while, conditional) hold,
// and the entry result layout (if set) matches the entry root.
Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              // Every buffer that may flow into this leaf must agree on the
              // minor-to-major order (dynamic dimensions ignored).
              for (const LogicalBuffer* buffer : buffers) {
                if (!Shape::Equal()
                         .IgnoreDynamicDimension()
                         .MinorToMajorOnlyInLayout()(instruction_subshape,
                                                     buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional: {
          // Collect per-branch computation layouts before delegating.
          std::vector<ComputationLayout> branch_computation_layouts;
          for (auto branch_computation : instruction->branch_computations()) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation));
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            module->result_shape(), result_layout.shape()));
  }
  return Status::OK();
}
1022
// Constructs the layout-assignment pass.
//
// `entry_computation_layout` is caller-owned; a copy is stored in
// saved_entry_computation_layout_ so the original can be restored later.
// `instruction_can_change_layout_func` decides whether a given instruction
// may have different layouts for its operands and output.
// `channel_constraints`, if non-null, is likewise snapshotted so it can be
// reset when undoing previous operations (ClearPreviousPassSideEffects()).
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    std::function<bool(const HloInstruction*)>
        instruction_can_change_layout_func,
    ChannelLayoutConstraints* channel_constraints)
    : entry_computation_layout_(entry_computation_layout),

      saved_entry_computation_layout_(*entry_computation_layout),
      channel_layout_constraints_(channel_constraints),
      instruction_can_change_layout_func_(
          std::move(instruction_can_change_layout_func)) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
1042
// Given the chosen `output_layout` of `instruction`, suggests a layout for
// operand `operand_no` that avoids copies (e.g. by making a reshape or
// transpose a bitcast). Returns nullptr when no preference can be derived.
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
    const Layout& output_layout, const HloInstruction* instruction,
    int64 operand_no) {
  const HloInstruction* operand = instruction->operand(operand_no);
  // This heuristic only applies to array-shaped values.
  CHECK(instruction->shape().IsArray());
  CHECK(operand->shape().IsArray());
  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == instruction->shape().rank() &&
      !instruction_can_change_layout_func_(instruction)) {
    // Propagate the result layout to the operand layout if the instruction
    // requires the same layout for the result and the operand.
    //
    // For elementwise operations, using the same layout for the operands and
    // the result also has the following benefits:
    // 1) the elementwise operation can reuse its operand's buffer, and
    // 2) the input and output elements can reuse the same linear index.
    return absl::make_unique<Layout>(output_layout);
  }

  if (instruction->opcode() == HloOpcode::kReshape) {
    // Prefer the operand layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the operand shape, there may be several such
    // layouts. So if 'output_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the operand's layout to the output.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(instruction->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }

    const Shape& output_shape = instruction->shape();
    Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
        LayoutUtil::MinorToMajor(output_layout));
    // Start from the default layout on the operand and ask AlignLayouts
    // whether some operand layout turns the reshape into a bitcast.
    Shape operand_shape = operand->shape();
    *operand_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(operand_shape);
    auto aligned_operand_shape =
        ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
    if (aligned_operand_shape) {
      auto operand_layout = aligned_operand_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
      return absl::make_unique<Layout>(operand_layout);
    }
  }

  if (instruction->opcode() == HloOpcode::kTranspose) {
    // Pick the operand layout that makes the transpose a bitcast.
    int64 rank = instruction->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    for (int64 i = 0; i < rank; ++i) {
      // Map each output minor-to-major position through the transpose's
      // dimension permutation to the corresponding operand dimension.
      int64 output_dim = LayoutUtil::Minor(output_layout, i);
      int64 operand_dim = instruction->dimensions(output_dim);
      new_minor_to_major[i] = operand_dim;
    }
    Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(
        LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
    return absl::make_unique<Layout>(operand_layout);
  }

  // No layout preference for this instruction/operand combination.
  return nullptr;
}
1109
// Mirror image of ChooseOperandLayoutFromOutputLayout: given the chosen
// `operand_layout` of operand `operand_no`, suggests an output layout for
// `user` that avoids copies. Returns nullptr when no preference can be
// derived.
std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
    const Layout& operand_layout, const HloInstruction* user,
    int64 operand_no) {
  const HloInstruction* operand = user->operand(operand_no);

  // This heuristic only applies to array-shaped values.
  CHECK(user->shape().IsArray() && operand->shape().IsArray());

  if (!ShapeUtil::IsScalar(operand->shape()) &&
      operand->shape().rank() == user->shape().rank() &&
      !instruction_can_change_layout_func_(user)) {
    // Assign users the same layout as the operand.
    return absl::make_unique<Layout>(operand_layout);
  }

  if (user->opcode() == HloOpcode::kReshape) {
    // Prefer the user layout that makes the reshape a bitcast. If any
    // dimension bound is 1 in the user shape, there may be several such
    // layouts. So if 'operand_layout' is the default layout, try if the
    // reshape is a bitcast when using the same layout. This may avoid copy
    // operations. For similar reasons, if the operand and output have the same
    // rank, try to match the output's layout to the operand.
    if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
        ShapeUtil::TrueRank(user->shape()) == 1) {
      // Don't assign a layout in case of R1 -> effective R1 reshape.
      return nullptr;
    }
    Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
        operand->shape().element_type(),
        AsInt64Slice(operand->shape().dimensions()),
        LayoutUtil::MinorToMajor(operand_layout));
    // Start from the default layout on the output and ask AlignLayouts
    // whether some output layout turns the reshape into a bitcast.
    Shape output_shape = user->shape();
    *output_shape.mutable_layout() =
        LayoutUtil::GetDefaultLayoutForShape(output_shape);
    auto aligned_user_shape =
        ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
    if (aligned_user_shape) {
      auto user_layout = aligned_user_shape.value().layout();
      TF_CHECK_OK(
          LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
      return absl::make_unique<Layout>(user_layout);
    }
  }

  if (user->opcode() == HloOpcode::kTranspose) {
    // Pick the user layout that makes the transpose a bitcast.
    int64 rank = user->shape().rank();
    std::vector<int64> new_minor_to_major(rank);
    // Inverting the permutation maps operand dimensions back to output
    // dimensions.
    auto inverse_dimensions = InversePermutation(user->dimensions());
    for (int64 i = 0; i < rank; ++i) {
      int64 operand_dim = LayoutUtil::Minor(operand_layout, i);
      int64 user_dim = inverse_dimensions[operand_dim];
      new_minor_to_major[i] = user_dim;
    }
    Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
    TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
    return absl::make_unique<Layout>(user_layout);
  }

  // No layout preference for this user/operand combination.
  return nullptr;
}
1170
// Drains the constraint worklist, dispatching each constraint to the matching
// Propagate*Constraint handler until a fixed point is reached (no handler adds
// new constraints).
Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
  // Gathers all initial constraints in a worklist and propagates them in
  // depth-first order. DFS order seems to be better than BFS because a
  // constraint is propagated as far as possible before propagating unrelated
  // constraints which makes it less likely that conflicting constraints will be
  // propagated to instructions. However, we should experiment with other orders
  // too.
  std::deque<const LayoutConstraint*> worklist;

  // Lambda for moving newly added constraints to the worklist.
  auto add_new_constraints_to_worklist = [constraints, &worklist]() {
    // Add constraints to the front of the deque for DFS ordering.
    for (auto* constraint : constraints->ConsumeAddedConstraints()) {
      if (constraint->dfs()) {
        worklist.push_front(constraint);
      } else {
        // Non-DFS constraints go to the back, yielding BFS order for them.
        worklist.push_back(constraint);
      }
    }
  };
  add_new_constraints_to_worklist();

  while (!worklist.empty()) {
    const LayoutConstraint* layout_constraint = worklist.front();
    worklist.pop_front();
    VLOG(2) << "Propagating " << layout_constraint->ToString()
            << " to its neighbors.";
    // Dispatch on the constraint's dynamic type: buffer, operand, or result.
    if (auto* buffer_constraint =
            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateBufferConstraint(*buffer_constraint, constraints));
    } else if (auto* operand_constraint =
                   dynamic_cast<const OperandLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateOperandConstraint(*operand_constraint, constraints));
    } else if (auto* result_constraint =
                   dynamic_cast<const ResultLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateResultConstraint(*result_constraint, constraints));
    } else {
      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
    }

    // Propagation may have produced new constraints; enqueue them.
    add_new_constraints_to_worklist();
  }
  return Status::OK();
}
1220
1221 namespace {
1222
1223 // Returns a vector containing all array-shaped uses (instruction and operand
1224 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const LogicalBuffer & buffer,const TuplePointsToAnalysis & points_to_analysis)1225 std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
1226 const LogicalBuffer& buffer,
1227 const TuplePointsToAnalysis& points_to_analysis) {
1228 CHECK(buffer.IsArray());
1229 std::vector<std::pair<const HloInstruction*, int64>> uses;
1230 for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
1231 if (!buffer_alias.instruction()->shape().IsArray()) {
1232 continue;
1233 }
1234 // This alias must be the top-level (index == {}) of the instruction's
1235 // result because the instruction produces an array.
1236 CHECK(buffer_alias.index().empty());
1237
1238 // Add all uses of the instruction's output.
1239 for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1240 for (int64 operand_no :
1241 user->OperandIndices(buffer_alias.instruction())) {
1242 uses.emplace_back(user, operand_no);
1243 }
1244 }
1245 }
1246 return uses;
1247 }
1248
1249 } // namespace
1250
// Pushes a use-side layout constraint backwards onto the logical buffers that
// may define the used value, so definitions adopt the layout their consumer
// wants (avoiding copies later).
Status LayoutAssignment::PropagateUseConstraintToDefs(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    LayoutConstraints* constraints) {
  // Try to set all logical buffers which may be sources of the given operand to
  // match the given layout.
  const PointsToSet& points_to_set =
      constraints->points_to_analysis().GetPointsToSet(instruction);
  return points_to_set.ForEachElementWithStatus(
      [&shape_layout, constraints](
          const ShapeIndex& index,
          const PointsToSet::BufferList& buffers) -> Status {
        if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
          for (const LogicalBuffer* buffer : buffers) {
            // Only constrain array buffers that are still unconstrained;
            // existing buffer layouts are left untouched.
            if (constraints->BufferLayout(*buffer) == nullptr &&
                buffer->shape().IsArray()) {
              TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                  ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
                  *buffer, /*mandatory=*/true));
            }
          }
        }
        return Status::OK();
      });
}
1275
1276 namespace {
1277 // A transpose or a reshape that only changes trivial dimensions have meaningful
1278 // layouts that are valuable to propagate in a depthfirst manner to avoid
1279 // unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo,bool forward_propagation=true)1280 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
1281 bool forward_propagation = true) {
1282 switch (hlo.opcode()) {
1283 case HloOpcode::kFusion:
1284 return hlo.IsCustomFusion();
1285 case HloOpcode::kGather:
1286 return true;
1287 case HloOpcode::kReshape:
1288 return hlo.operand(0)->shape().rank() == 1 ||
1289 (forward_propagation &&
1290 std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
1291 case HloOpcode::kScatter:
1292 case HloOpcode::kTranspose:
1293 return true;
1294 default:
1295 return false;
1296 }
1297 }
1298
1299 } // namespace
1300
// Propagates an operand layout constraint outward: backwards to the buffers
// that define the operand, sideways to sibling operands of non-layout-changing
// users, and forwards to the user's output buffers.
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    return Status::OK();
  }

  if (user->opcode() == HloOpcode::kAllReduce) {
    // A single-operand all-reduce defines its buffer at the top level; a
    // multi-operand one defines one buffer per operand index.
    const auto shape_index =
        user->operand_count() == 1
            ? ShapeIndex()
            : ShapeIndex({operand_constraint.operand_no()});
    TF_ASSIGN_OR_RETURN(const LogicalBuffer* buffer,
                        constraints->points_to_analysis().GetBufferDefinedAt(
                            user, shape_index));
    const BufferLayoutConstraint* constraint =
        constraints->GetBufferLayoutConstraint(*buffer);
    if (constraint == nullptr) {
      // Suggest (non-mandatorily) that the all-reduce output share the
      // operand's layout.
      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          operand_constraint.shape_layout().layout(), *buffer,
          /*mandatory=*/false));
    }
  }
  if (instruction_can_change_layout_func_(user) && !user->shape().IsArray()) {
    return Status::OK();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->OperandBufferForwarded(user,
                                          operand_constraint.operand_no())) {
    return Status::OK();
  }

  // Rank <= 1 shapes have a unique layout; nothing to propagate.
  int64 operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    return Status::OK();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!instruction_can_change_layout_func_(user)) {
    // Make sure all siblings have the same layout as the operand.
    for (int64 operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      const int64 sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        // operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        // we continue to propagate our layout to the result, and then
        // backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false));
    }
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return Status::OK();
          }
          if (subshape.rank() <= 1) {
            return Status::OK();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return Status::OK();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(
                  user, shape_index));
          // Make sure the output has the same layout as the operand.
          const BufferLayoutConstraint* constraint =
              constraints->GetBufferLayoutConstraint(*buffer);
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          if (constraint == nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                operand_constraint.shape_layout().layout(), *buffer,
                /*mandatory=*/false));
          }
          return Status::OK();
        }));
    return Status::OK();
  }
  // The user may change layouts: ask the opcode-aware heuristic for a cheap
  // output layout derived from the operand's layout.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return Status::OK();
        }
        if (subshape.rank() <= 1) {
          return Status::OK();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(user,
                                                                 shape_index));
        // Only overwrite unset or non-mandatory buffer constraints.
        if (constraints->BufferLayout(*buffer) == nullptr ||
            !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}
1459
// Propagates the layout of 'buffer_constraint' backward to the operands of
// the instruction that defines the constrained buffer, so producers are laid
// out compatibly with this buffer.
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  // Rank <= 1 shapes have only one possible layout; nothing to propagate.
  if (IsAtMostRank1(instruction->shape())) {
    return Status::OK();
  }

  if (instruction->opcode() == HloOpcode::kAllReduce) {
    // For a multi-operand all-reduce, buffer.index()[0] identifies which
    // tuple element (and therefore which operand) this buffer corresponds to.
    TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
        buffer_constraint.layout(), instruction,
        instruction->operand_count() == 1 ? 0 : buffer.index()[0],
        /*mandatory=*/true));
    return Status::OK();
  }
  for (int64 operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    // Skip operands whose layout is trivially determined by their rank.
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!instruction_can_change_layout_func_(instruction)) {
      // Layout-preserving instruction: copy the layout to the operand
      // verbatim, but only when the ranks agree so the minor-to-major order
      // is meaningful for the operand's shape.
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true));
      }
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no mandatory constraint already; a
      // non-mandatory existing constraint may be overwritten.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        // Ask the (possibly backend-overridden) heuristic for a good operand
        // layout given the output layout; nullptr means "no preference".
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/
              InstructionShouldPropagateDepthFirst(
                  *instruction, /*forward_propagation=*/false)));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return Status::OK();
}
1523
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1524 Status LayoutAssignment::PropagateBufferConstraint(
1525 const BufferLayoutConstraint& buffer_constraint,
1526 LayoutConstraints* constraints) {
1527 // Only propagate array layouts.
1528 const LogicalBuffer& buffer = buffer_constraint.buffer();
1529 if (!buffer.IsArray()) {
1530 return Status::OK();
1531 }
1532 TF_RETURN_IF_ERROR(
1533 PropagateBufferConstraintToUses(buffer_constraint, constraints));
1534 return PropagateBufferConstraintToOperands(buffer_constraint, constraints);
1535 }
1536
// Propagates the layout of 'buffer_constraint' forward to all array uses of
// the constrained buffer, and across the backedge of an enclosing kWhile.
Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
    const HloInstruction* user = user_operand_no.first;
    int64 operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !constraints->OperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
    }
  }

  // Propagate to backedges of kWhile: this computation must be called from
  // exactly one place, and that caller must be a while instruction.
  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
  if (node.caller_callsites().size() != 1) {
    return Status::OK();
  }
  const HloInstruction* parent = node.caller_callsites()[0].instruction();
  if (parent->opcode() != HloOpcode::kWhile) {
    return Status::OK();
  }

  for (HloInstruction* user : buffer.instruction()->users()) {
    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
      continue;
    }
    if (user->parent()->root_instruction() == user) {
      // The buffer feeds element 'index' of the body's root tuple; the loop
      // backedge aliases that element with the same element of the body
      // parameter, so suggest the same layout there (non-mandatory).
      VLOG(3) << "Propagating layout through backedge"
              << buffer_constraint.layout().ToString();
      int64 index = user->operand_index(buffer.instruction());
      TF_ASSIGN_OR_RETURN(
          auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
                           user->parent()->parameter_instruction(0), {index}));

      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          buffer_constraint.layout(), *buffer, /*mandatory=*/false));
    }
  }

  return Status::OK();
}
1587
PropagateResultConstraint(const ResultLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1588 Status LayoutAssignment::PropagateResultConstraint(
1589 const ResultLayoutConstraint& layout_constraint,
1590 LayoutConstraints* constraints) {
1591 // Propagate the use constraint of the root instruction up to the logical
1592 // buffers which make up the result.
1593 return PropagateUseConstraintToDefs(
1594 layout_constraint.shape_layout(),
1595 constraints->computation()->root_instruction(), constraints);
1596 }
1597
namespace {

// Infers the layout of the array at the given index in the given instruction's
// output using points-to analysis. Precondition: The given instruction must
// not produce this array value (that is, the array is forwarded from the
// instruction's operands). Returns an error if the aliased source buffers
// disagree on layout (ambiguous points-to set) or if one lacks a layout.
StatusOr<Layout> InferArrayLayout(
    const TuplePointsToAnalysis& points_to_analysis,
    HloInstruction* instruction, const ShapeIndex& index) {
  // This function should only be called for array shapes which don't yet have
  // layouts.
  const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
  TF_RET_CHECK(subshape.IsArray());
  TF_RET_CHECK(!subshape.has_layout());

  // The instruction should not define the buffer at this index.
  TF_RET_CHECK(
      !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
      << instruction->ToString();

  const auto& source_buffers =
      points_to_analysis.GetPointsToSet(instruction).element(index);
  TF_RET_CHECK(!source_buffers.empty());

  // Verify the layout is the same for every LogicalBuffer which this location
  // ('instruction' and 'index') points to. Only minor-to-major order is
  // compared; other layout fields are ignored here.
  const Layout* first_buffer_layout = nullptr;
  for (const LogicalBuffer* source_buffer : source_buffers) {
    if (!source_buffer->shape().has_layout()) {
      // This should not happen because we've assigned layouts to all
      // instructions preceding this one.
      return InternalError("LogicalBuffer %s does not have a layout",
                           source_buffer->ToString());
    }

    if (first_buffer_layout == nullptr) {
      first_buffer_layout = &source_buffer->shape().layout();
    } else if (!Layout::Equal().MinorToMajorOnly()(
                   source_buffer->shape().layout(), *first_buffer_layout)) {
      // The points-to set is ambiguous for this index and the different source
      // buffers have different layouts. This case is possible in valid XLA
      // computations because we do not propagate BufferLayoutConstraints to all
      // LogicalBuffers which may alias the constrained LogicalBuffer at some
      // point in the computation.
      return FailedPrecondition(
          "Array at index {%s} in instruction %s aliases buffers %s "
          "and %s which have different layouts",
          absl::StrJoin(index, ","), instruction->name(),
          source_buffers[0]->ToString(), source_buffer->ToString());
    }
  }

  return *first_buffer_layout;
}

// For fusion instructions, set the layout of each fused parameter instruction
// to match the layout of its corresponding fusion instruction operand. Also,
// set the layout of the fused root to match the layout of the fusion
// instruction itself. All other fused instructions have their layouts
// cleared (except in custom fusions, and a few special opcodes below).
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // A fused parameter mirrors the corresponding fusion operand, so it
      // must carry exactly the operand's layout.
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes.
      // But do not clear layouts for other instructions in custom fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return Status::OK();
}

}  // namespace
1700
// Applies the fully-propagated 'constraints' to every instruction of
// 'computation': sets the layouts of all defined array buffers, infers
// layouts of forwarded (aliased) arrays via points-to analysis, inserts
// copies where operand layouts disagree with use constraints, and finally
// verifies that every shape ends up with a complete, valid layout.
Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
                                       HloComputation* computation) {
  VLOG(2) << "Assigning layouts to computation: " << computation->name();
  XLA_VLOG_LINES(2, computation->ToString());
  XLA_VLOG_LINES(2, constraints.ToString());

  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    // Start from a clean slate so stale layouts can't leak through.
    LayoutUtil::ClearLayout(instruction->mutable_shape());

    // Set the layouts of the array shapes this instruction defines as indicated
    // by the respective BufferLayoutConstraints. Any array shapes in the output
    // of the instruction which are not defined by the instruction (eg, array
    // elements in a Tuple instruction) will be assigned below via inference.
    for (const LogicalBuffer* buffer :
         constraints.points_to_analysis().GetBuffersDefinedByInstruction(
             instruction)) {
      if (!buffer->shape().IsArray()) {
        continue;
      }

      TF_RET_CHECK(buffer->instruction() == instruction);
      // Every buffer must have been constrained by the propagation phase.
      const Layout* buffer_layout = constraints.BufferLayout(*buffer);
      TF_RET_CHECK(buffer_layout != nullptr);

      if (instruction->opcode() == HloOpcode::kConstant) {
        // For constants, we also need to change the layout of the internal
        // literal.
        instruction->RelayoutConstant(*buffer_layout, buffer->index());
      } else {
        Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
            instruction->mutable_shape(), buffer->index());
        *buffer_subshape->mutable_layout() = *buffer_layout;
      }
    }

    // Any remaining layouts in the output of the instruction must be
    // inferrable using points-to analysis.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
        instruction->mutable_shape(),
        [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
          if (subshape->has_layout() || !subshape->IsArray()) {
            return Status::OK();
          }
          // Set Layout of subshape to match layout of LogicalBuffer which
          // produces it.
          TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
                              InferArrayLayout(constraints.points_to_analysis(),
                                               instruction, index));
          return Status::OK();
        }));

    // Create a copy of an operand if the operand instruction's layout does not
    // match the use constraint (OperandLayoutConstraint).
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const ShapeLayout* operand_layout =
          constraints.OperandLayout(instruction, operand_no);
      if (operand_layout != nullptr) {
        TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
                                                      instruction, operand_no));
      }
    }

    // Fusion instructions require some layouts to be set on fused instructions
    // inside the fusion instruction.
    if (instruction->opcode() == HloOpcode::kFusion) {
      TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
    }

    // Execute extra verification step once the layout has been finalized.
    TF_RETURN_IF_ERROR(Verify(instruction));

    // Shape must be valid.
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));

    // Verify all layouts in the shape have been set.
    TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
  }
  return Status::OK();
}
1782
CalculateComputationLayout(HloComputation * computation)1783 Status LayoutAssignment::CalculateComputationLayout(
1784 HloComputation* computation) {
1785 ComputationLayout computation_layout(computation->ComputeProgramShape(),
1786 /*ignore_layouts=*/false);
1787 InsertOrDie(&computation_layouts_, computation, computation_layout);
1788 VLOG(2) << " Calculated ComputationLayout = "
1789 << computation_layout.ToString();
1790 return Status::OK();
1791 }
1792
ClearComputationLayouts(HloComputation * computation)1793 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
1794 // Clear existing layouts of the instructions. All layouts must be assigned
1795 // by the LayoutAssignment pass, except for those on parameters, the
1796 // computation result, and a couple special cases. The former two are
1797 // specified in computation_layout. Clearing the layouts here avoids hiding
1798 // potential bugs in the layout assignment pass that may accidentally use the
1799 // existing layout.
1800 for (HloInstruction* instruction : computation->instructions()) {
1801 if (instruction->opcode() == HloOpcode::kBitcast) {
1802 // bitcasts are inherently layout sensitive and so a bitcast instruction
1803 // present in the IR before layout assignment is a bug.
1804 return InternalError(
1805 "Unexpected bitcast operation seen during layout assignment: %s.",
1806 instruction->ToString());
1807 }
1808 // Some instructions carry mandatory layouts in their shape.
1809 if (instruction->opcode() != HloOpcode::kInfeed &&
1810 !IsLayoutConstrainedCustomCall(instruction) &&
1811 !IsLayoutConstrainedAllReduce(instruction)) {
1812 LayoutUtil::ClearLayout(instruction->mutable_shape());
1813 }
1814 }
1815 return Status::OK();
1816 }
1817
// Runs layout assignment on a single computation: collects mandatory and
// backend constraints, propagates them to a fixed point, assigns default
// layouts to any still-unconstrained buffers, applies the final layouts, and
// reconciles channel and result-layout requirements. 'computation_layout'
// may be null, in which case the computation's layout is calculated from the
// layouts that naturally flowed to its parameters and root.
Status LayoutAssignment::RunOnComputation(
    ComputationLayout* computation_layout, HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
          << ")";

  // Must be run before clearing layouts.
  TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));

  TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
  if (computation_layout != nullptr) {
    auto it = computation_layouts_.find(computation);
    if (it == computation_layouts_.end()) {
      VLOG(2) << "  New ComputationLayout = " << computation_layout->ToString();
      computation_layouts_.emplace(computation, *computation_layout);
    } else {
      // A caller-supplied layout must be the recorded one (or the entry
      // layout); anything else indicates inconsistent bookkeeping.
      TF_RET_CHECK(computation_layout == &it->second ||
                   computation_layout == entry_computation_layout_);
      VLOG(2) << "  Existing ComputationLayout = "
              << computation_layout->ToString();
    }
  } else {
    VLOG(2) << "  No ComputationLayout specified (will be calculated)";
  }

  // Construct LayoutConstraints with all layout constraints of the computation.
  LayoutConstraints constraints(*points_to_analysis_, computation);

  // Add constraints required for correctness on all backends (eg, entry
  // parameter layout constraints).
  TF_RETURN_IF_ERROR(AddMandatoryConstraints(
      computation_layout, channel_constraints, computation, &constraints));

  // Add any backend-specific constraints.
  TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));

  // Propagates layouts from mandatory and backend constraints.
  TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

  // Prior to applying default layouts, we take note of all HLO instructions
  // which lack a layout constraint.
  for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
    unconstrained_layout_instructions_.insert(
        points_to_analysis_->GetBuffer(buffer_id).instruction());
  }

  // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
  // layout and propagate the change.
  while (!constraints.unconstrained_buffer_ids().empty()) {
    int unconstrained_count = constraints.unconstrained_buffer_ids().size();

    // Arbitrarily pick the first unconstrained buffer and give it the default
    // layout (or the literal layout, in case of constants). By construction
    // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
    const LogicalBuffer& buffer = points_to_analysis_->GetBuffer(
        *constraints.unconstrained_buffer_ids().begin());
    const HloInstruction* instruction = buffer.instruction();
    Layout new_layout =
        instruction->opcode() == HloOpcode::kConstant
            ? ShapeUtil::GetSubshape(instruction->literal().shape(),
                                     buffer.index())
                  .layout()
            : LayoutUtil::GetDefaultLayoutForShape(buffer.shape());
    TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
                                                   /*mandatory=*/false));

    TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

    // To verify progress has been made, check that the number of unconstrained
    // buffers has been reduced.
    CHECK_LT(constraints.unconstrained_buffer_ids().size(),
             unconstrained_count);
  }
  // All logical buffers should have constraints at this point. All that
  // remains is assign the constraints to the buffers and infer layouts for
  // aliased buffers.
  TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));

  // If the computation layout wasn't specified, now it is the time to compute
  // it according to the parameters and root instruction layouts.
  // This allows the first pass through this API to record the best flowing
  // layout to parameters and root instruction.
  if (computation_layout == nullptr) {
    TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
  }

  // Record the layouts assigned for any communication ops in
  // channel_constraints so that they are constrained for future modules.
  if (channel_constraints != nullptr) {
    TF_RETURN_IF_ERROR(
        ConstrainChannelLayouts(computation, channel_constraints));
  }

  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape(),
          /*minor_to_major_only=*/true)) {
    // Conditional branches recorded a mismatching result layout earlier;
    // honor that recorded layout for this computation.
    if (conditional_mismatch_.count(computation) > 0) {
      *FindOrDie(computation_layouts_, computation).mutable_result_layout() =
          FindOrDie(conditional_mismatch_, computation).result_layout();
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }
  return Status::OK();
}
1929
ConstrainChannelLayouts(HloComputation * computation,ChannelLayoutConstraints * channel_constraints)1930 Status LayoutAssignment::ConstrainChannelLayouts(
1931 HloComputation* computation,
1932 ChannelLayoutConstraints* channel_constraints) {
1933 auto get_channel_constraints = [&](const HloInstruction* instruction) {
1934 return IsHostSendRecv(instruction) ? &host_channel_constraints_
1935 : channel_constraints;
1936 };
1937 // We go through the kRecvDone before. These must either impose their layout,
1938 // or find a matching one already existing (ConstrainChannel() returns
1939 // nullptr).
1940 for (HloInstruction* instruction : computation->instructions()) {
1941 if (instruction->opcode() == HloOpcode::kRecvDone) {
1942 const Layout* layout =
1943 get_channel_constraints(instruction)
1944 ->ConstrainChannel(
1945 *instruction->channel_id(),
1946 ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
1947 TF_RET_CHECK(layout == nullptr)
1948 << instruction->ToString()
1949 << " cannot constrain layout as it was set to "
1950 << LayoutUtil::HumanString(*layout);
1951 }
1952 }
1953 // After that we go through the kSend. These are likely going to have a kCopy
1954 // as operand (otherwise we add it), so in case the constrained layout does
1955 // not match, we can change the kCopy layout (and the kSend one as well).
1956 for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
1957 if (instruction->opcode() == HloOpcode::kSend) {
1958 HloInstruction* operand = instruction->mutable_operand(0);
1959 get_channel_constraints(instruction)
1960 ->ConstrainChannel(*instruction->channel_id(),
1961 operand->shape().layout());
1962 } else if (instruction->IsCrossModuleAllReduce()) {
1963 get_channel_constraints(instruction)
1964 ->ConstrainChannel(instruction->channel_id().value(),
1965 instruction->shape().layout());
1966 }
1967 }
1968 return Status::OK();
1969 }
1970
PropagateMemorySpace(HloModule * module)1971 Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
1972 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
1973 for (auto buffer : alias_analysis->buffers()) {
1974 // First go through values to collect the memory spaces.
1975 int64 buffer_memory_space = Layout::kDefaultMemorySpace;
1976 for (auto value : buffer.values()) {
1977 const Shape& defining_shape = value->defining_position().shape();
1978 int64 memory_space = defining_shape.layout().memory_space();
1979 if (memory_space != Layout::kDefaultMemorySpace) {
1980 if (buffer_memory_space != Layout::kDefaultMemorySpace &&
1981 memory_space != buffer_memory_space) {
1982 return InternalError(
1983 "Buffer %d (%s) has conflicting memory spaces: %d and %d.",
1984 buffer.id(), value->ToShortString(), buffer_memory_space,
1985 memory_space);
1986 }
1987 buffer_memory_space = memory_space;
1988 }
1989 }
1990
1991 // If we encounter a memory space other than the default, then propagate all
1992 // the positions with the buffer's memory space.
1993 if (buffer_memory_space != Layout::kDefaultMemorySpace) {
1994 for (auto value : buffer.values()) {
1995 for (auto& position : value->positions()) {
1996 Shape* shape = ShapeUtil::GetMutableSubshape(
1997 position.instruction->mutable_shape(), position.index);
1998 shape->mutable_layout()->set_memory_space(buffer_memory_space);
1999 }
2000 }
2001 }
2002 }
2003 return Status::OK();
2004 }
2005
// Reconciles 'computation_layout' (the requested layout) with the layouts
// that were actually assigned to 'computation'. Unset parameter layouts and
// an unset result layout are filled in from the computed layouts; layouts
// that were already set must match the computed ones (result compared on
// minor-to-major only, ignoring dynamic dimensions).
Status LayoutAssignment::PropagateComputationLayouts(
    HloComputation* computation, ComputationLayout* computation_layout) {
  ComputationLayout computed_computation_layout(
      computation->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
    ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
    bool needs_assign = false;
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        param_layout->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          // Only leaf (array) subshapes carry layouts worth checking.
          if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
            return Status::OK();
          }
          // A leaf with no layout means the whole parameter layout must be
          // adopted from the computed layout below.
          if (!subshape.has_layout()) {
            needs_assign = true;
            return Status::OK();
          }
          // A pre-set leaf layout must agree exactly with what was assigned.
          const auto& computed_subshape = ShapeUtil::GetSubshape(
              computed_computation_layout.parameter_shape(i), shape_index);
          if (subshape.layout() != computed_subshape.layout()) {
            return InternalError(
                "Assigned parameter shape %s does not match layout of "
                "computation shape: %s",
                computed_computation_layout.ToString(),
                computation_layout->ToString());
          }
          return Status::OK();
        }));
    if (needs_assign) {
      VLOG(4) << "Assigning layout to parameter " << i << " of computation "
              << computation->name() << ": "
              << computed_computation_layout.parameter_layout(i).ToString();
      *param_layout = computed_computation_layout.parameter_layout(i);
    }
  }
  ShapeLayout* result_layout = computation_layout->mutable_result_layout();
  if (!result_layout->LayoutIsSet()) {
    VLOG(4) << "Assigning result layout of computation " << computation->name()
            << ": " << computed_computation_layout.result_layout().ToString();
    *result_layout = computed_computation_layout.result_layout();
  } else {
    // The result layout was requested explicitly; it must match what was
    // assigned, comparing only minor-to-major and ignoring dynamic dims.
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            computed_computation_layout.result_layout().shape(),
            result_layout->shape()));
  }
  return Status::OK();
}
2055
// Pass entry point. Prepares the module (copies before kSend operands,
// clones multiply-used conditional branch computations), then runs layout
// assignment twice over all non-fusion computations: the first pass lets
// layouts flow freely to parameters/roots while fixing up cross-computation
// inconsistencies, the second makes the resulting ComputationLayouts
// mandatory. Finally propagates entry layouts, memory spaces, and verifies.
StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
  VLOG(2) << "Running layout assignment on module " << module->name();
  TF_RETURN_IF_ERROR(Init());
  call_graph_ = CallGraph::Build(module);
  auto computations = module->computations();

  // Add copy to the operand of Send instructions, since we cannot call
  // SetOperandLayout on Send instructions as it aliases its input to the
  // output.
  //
  // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
  // operand buffers that aliases with the output.
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kSend) {
        TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
      }
    }
  }

  // Clone Conditional computations with multiple callsites, so each branch
  // computation can receive its own independent layout.
  for (HloComputation* computation : computations) {
    CallGraphNode& node = call_graph_->GetNode(computation);
    if (node.caller_callsites().size() == 1) {
      continue;
    }
    if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
          return caller.instruction()->opcode() == HloOpcode::kConditional;
        })) {
      continue;
    }
    // Keep the original computation for the last caller; all earlier
    // conditional callers get a fresh clone.
    for (int64 i = 0; i < node.caller_callsites().size() - 1; ++i) {
      HloInstruction* caller = node.caller_callsites()[i].instruction();
      if (caller->opcode() == HloOpcode::kConditional) {
        for (int64 k = 0; k < caller->branch_count(); ++k) {
          if (computation == caller->branch_computation(k)) {
            caller->set_branch_computation(
                k, module->AddEmbeddedComputation(computation->Clone()));
            break;
          }
        }
      }
    }
  }

  // Verify computation layout is sane: parameter count and shapes of the
  // requested entry layout must be compatible with the entry computation.
  const HloComputation* entry = module->entry_computation();
  TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
               entry->num_parameters());
  for (int64 i = 0; i < entry->num_parameters(); ++i) {
    TF_RET_CHECK(
        ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
                              entry->parameter_instruction(i)->shape()));
  }
  TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
                                     entry->root_instruction()->shape()));

  // We do two passes. The first one we pass a nullptr ComputationLayout to
  // the RunOnComputation() calls (for non entry computations), and we register
  // the ComputationLayout which are naturally flowing in DFS fashion to the
  // parameters and root instruction.
  // Walking in DFS mode though, means that we can end up with incorrect layouts
  // when seen from an outer instruction, which has across-computation
  // constraints to impose.
  // For example, the kWhile instruction needs to enforce the same layouts for
  // the parameters and root of the body, as well as the condition parameters.
  // Similarly, the kConditional instruction needs to enforce the same layouts
  // for the root of the true and false computations.
  // So in the first pass, while allowing the layouts to flow to parameters and
  // root, we also fix up the eventually inconsistent ComputationLayout, which
  // will be then made mandatory by the second pass.
  for (int64 i = 0; i < 2; ++i) {
    VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
    TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
    // Points-to analysis is rebuilt each pass since the previous pass may
    // have added/removed instructions.
    TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                        TuplePointsToAnalysis::Run(module));
    points_to_analysis_ = std::move(points_to_analysis);
    for (auto* computation : module->MakeComputationPostOrder()) {
      if (computation->IsFusionComputation()) {
        continue;
      }
      if (computation == module->entry_computation()) {
        TF_RETURN_IF_ERROR(RunOnComputation(entry_computation_layout_,
                                            module->entry_computation(),
                                            channel_layout_constraints_));
      } else {
        // Pass 0 (and any computation flagged with a conditional mismatch)
        // runs unconstrained; pass 1 enforces the recorded layout.
        ComputationLayout* computation_layout =
            (i == 0 || conditional_mismatch_.count(computation) > 0)
                ? nullptr
                : &FindOrDie(computation_layouts_, computation);
        TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation,
                                            channel_layout_constraints_));
      }
    }
  }
  TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
                                                 entry_computation_layout_));

  TF_RETURN_IF_ERROR(PropagateMemorySpace(module));

  TF_RETURN_IF_ERROR(CheckLayouts(module));

  // All layouts are reset then reassigned by this pass.
  return true;
}
2162
/* static */
// Returns true if `instruction` may produce a result whose layout differs
// from its operands' layouts (so layout assignment may pick the output layout
// independently), and false if the output layout is tied to the operand
// layouts and must not be changed by this pass.
bool LayoutAssignment::InstructionCanChangeLayout(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    // Opcodes whose result layout must follow their operand layout(s):
    // mostly elementwise ops, plus ops (slice, pad, sort, while, ...) whose
    // implementation assumes operand and result agree on layout.
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kWhile:
    case HloOpcode::kSetDimensionSize:
    // AllReduce is variadic so it needs to be careful to assign the same layout
    // to the corresponding input argument and Tuple index.
    case HloOpcode::kAllReduce:
      return false;
    // Opcodes whose result layout is independent of the operand layouts —
    // e.g. kCopy/kTranspose/kReshape exist precisely to change layout, and
    // ops like kDot/kConvolution/kReduce may pick any output layout.
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kAfterAll:
    case HloOpcode::kTrace:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
      return true;
  }
  // No default: the compiler enforces that every HloOpcode is handled above.
}
2274
2275 /* static */
IsAtMostRank1(const Shape & shape)2276 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2277 if (shape.IsArray()) {
2278 return shape.rank() <= 1;
2279 }
2280 return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2281 return IsAtMostRank1(subshape);
2282 });
2283 }
2284
Init()2285 Status LayoutAssignment::Init() {
2286 computation_layouts_.clear();
2287 conditional_mismatch_.clear();
2288 *entry_computation_layout_ = saved_entry_computation_layout_;
2289 return Status::OK();
2290 }
2291
ClearPreviousPassSideEffects(HloModule * module)2292 Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
2293 VLOG(5) << "Clearing previous side effects";
2294 // Clear all the copies which have been added, and all the related
2295 // instructions (like GTE and tuples).
2296 int64 removed_copies = 0;
2297 for (HloComputation* computation : module->computations()) {
2298 for (HloInstruction* instruction :
2299 computation->MakeInstructionPostOrder()) {
2300 if (instruction->opcode() == HloOpcode::kCopy &&
2301 added_copies_.contains(instruction)) {
2302 VLOG(5) << "Removing added copy: " << instruction->ToString();
2303 TF_RETURN_IF_ERROR(
2304 instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2305 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2306 ++removed_copies;
2307 }
2308 }
2309 }
2310 added_copies_.clear();
2311 unconstrained_layout_instructions_.clear();
2312 if (removed_copies > 0) {
2313 TupleSimplifier tuple_simplifier;
2314 HloDCE dce;
2315 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2316 TF_RETURN_IF_ERROR(dce.Run(module).status());
2317 }
2318 return Status::OK();
2319 }
2320
AddCopyForOperand(HloInstruction * instruction,int64 operand_number)2321 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2322 int64 operand_number) {
2323 HloInstruction* operand = instruction->mutable_operand(operand_number);
2324 if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2325 HloInstruction* copy =
2326 instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2327 operand->shape(), HloOpcode::kCopy, operand));
2328 SetupCopiedInstruction(*operand, copy, {});
2329 LayoutUtil::ClearLayout(copy->mutable_shape());
2330 TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2331 }
2332 return Status::OK();
2333 }
2334
2335 } // namespace xla
2336