1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/layout_assignment.h"
17
18 #include <algorithm>
19 #include <deque>
20 #include <functional>
21 #include <map>
22 #include <memory>
23 #include <numeric>
24 #include <ostream>
25 #include <set>
26 #include <string>
27 #include <tuple>
28
29 #include "absl/algorithm/container.h"
30 #include "absl/memory/memory.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/map_util.h"
37 #include "tensorflow/compiler/xla/permutation_util.h"
38 #include "tensorflow/compiler/xla/service/call_graph.h"
39 #include "tensorflow/compiler/xla/service/computation_layout.h"
40 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_dce.h"
44 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
45 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/logical_buffer.h"
48 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
49 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
50 #include "tensorflow/compiler/xla/shape_layout.h"
51 #include "tensorflow/compiler/xla/shape_util.h"
52 #include "tensorflow/compiler/xla/status_macros.h"
53 #include "tensorflow/compiler/xla/statusor.h"
54 #include "tensorflow/compiler/xla/types.h"
55 #include "tensorflow/compiler/xla/util.h"
56 #include "tensorflow/compiler/xla/xla_data.pb.h"
57 #include "tensorflow/core/lib/core/errors.h"
58 #include "tensorflow/core/lib/core/status.h"
59 #include "tensorflow/core/platform/logging.h"
60 #include "tensorflow/core/platform/protobuf.h"
61
62 namespace xla {
63
operator <<(std::ostream & out,const LayoutConstraint & constraint)64 std::ostream& operator<<(std::ostream& out,
65 const LayoutConstraint& constraint) {
66 out << constraint.ToString();
67 return out;
68 }
69
// Constructs a constraint that pins `buffer` to `layout`.
// CHECK-fails if `layout` is not a valid layout for the buffer's shape.
// `mandatory`, `dfs`, and `priority` are forwarded to the LayoutConstraint
// base and control how (and in which order) the constraint is propagated.
BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
                                               const LogicalBuffer& buffer,
                                               bool mandatory, bool dfs,
                                               int64_t priority)
    : LayoutConstraint(mandatory, dfs, priority),
      layout_(layout),
      buffer_(&buffer) {
  CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
}
79
ToString() const80 string BufferLayoutConstraint::ToString() const {
81 return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(),
82 LayoutUtil::HumanString(layout_));
83 }
84
// Constructs a constraint on operand `operand_no` of `instruction` to have
// the given `shape_layout`.
// CHECK-fails if the layout is unset, or if the constrained shape is not
// compatible (same tree structure / element types) with the operand's shape.
OperandLayoutConstraint::OperandLayoutConstraint(
    const ShapeLayout& shape_layout, const HloInstruction* instruction,
    int64_t operand_no, bool mandatory, bool dfs, int64_t priority)
    : LayoutConstraint(mandatory, dfs, priority),
      shape_layout_(shape_layout),
      instruction_(instruction),
      operand_no_(operand_no) {
  CHECK(shape_layout_.LayoutIsSet());
  CHECK(ShapeUtil::Compatible(shape_layout.shape(),
                              instruction->operand(operand_no)->shape()))
      << shape_layout.shape() << " is not compatible with "
      << instruction->operand(operand_no)->shape() << " (for operand "
      << operand_no << " of instruction " << instruction->ToString() << ")";
}
99
ToString() const100 string OperandLayoutConstraint::ToString() const {
101 return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s",
102 instruction_->name(), operand_no_,
103 shape_layout_.ToString());
104 }
105
ToString() const106 string ResultLayoutConstraint::ToString() const {
107 return absl::StrFormat("ResultLayoutConstraint: %s",
108 shape_layout_.ToString());
109 }
110
LayoutConstraints(const TuplePointsToAnalysis & points_to_analysis,HloComputation * computation)111 LayoutConstraints::LayoutConstraints(
112 const TuplePointsToAnalysis& points_to_analysis,
113 HloComputation* computation)
114 : points_to_analysis_(points_to_analysis), computation_(computation) {
115 // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
116 for (HloInstruction* inst : computation_->instructions()) {
117 points_to_analysis_.GetPointsToSet(inst).ForEachElement(
118 [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
119 for (const LogicalBuffer* buffer : buffers) {
120 // The points to analysis is computed per module, restrict
121 // constraints to array buffers in this computation.
122 if (buffer->IsArray() &&
123 buffer->instruction()->parent() == computation) {
124 unconstrained_buffer_ids_.insert(buffer->id());
125 }
126 }
127 });
128 }
129 }
130
GetBufferSet(const HloInstruction * instruction) const131 PointsToSet::BufferSet* LayoutConstraints::GetBufferSet(
132 const HloInstruction* instruction) const {
133 auto it = buffer_sets_cache_.find(instruction);
134 if (it != buffer_sets_cache_.end()) {
135 return it->second.get();
136 }
137 auto& buffer_set =
138 buffer_sets_cache_
139 .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>())
140 .first->second;
141 const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction);
142 points_to_set.ForEachElement(
143 [&buffer_set](const ShapeIndex& /*index*/,
144 const PointsToSet::BufferList& buffers) {
145 buffer_set->insert(buffers.begin(), buffers.end());
146 });
147 return buffer_set.get();
148 }
149
AnyOperandBufferForwarded(const HloInstruction * instruction,int64_t operand_no) const150 bool LayoutConstraints::AnyOperandBufferForwarded(
151 const HloInstruction* instruction, int64_t operand_no) const {
152 // The operand is potentially forwarded if the intersection of points-to sets
153 // of the operand and the instruction is non-empty.
154 PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
155 PointsToSet::BufferSet* operand_buffers =
156 GetBufferSet(instruction->operand(operand_no));
157 return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
158 return operand_buffers->count(b) > 0;
159 });
160 }
161
AllOperandBuffersForwarded(const HloInstruction * instruction,int64_t operand_no) const162 bool LayoutConstraints::AllOperandBuffersForwarded(
163 const HloInstruction* instruction, int64_t operand_no) const {
164 // The operand is potentially forwarded if the intersection of points-to sets
165 // of the operand and the instruction is non-empty.
166 PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
167 PointsToSet::BufferSet* operand_buffers =
168 GetBufferSet(instruction->operand(operand_no));
169 return absl::c_all_of(*output_buffers, [&](const LogicalBuffer* b) {
170 return operand_buffers->count(b) > 0;
171 });
172 }
173
// Adds (or replaces) a layout constraint on `buffer`.
//
// Rules visible in the code below:
//  - the buffer must be array-shaped and `layout` valid for its shape;
//  - a constraint identical (minor-to-major only) to the existing one is a
//    no-op;
//  - a non-mandatory request never overrides an existing mandatory one;
//  - two conflicting mandatory constraints are a FailedPrecondition;
//  - otherwise the existing non-mandatory constraint is overwritten.
// A newly constrained buffer is removed from unconstrained_buffer_ids_, and
// the (new or updated) constraint is appended to the propagation worklist.
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
                                          const LogicalBuffer& buffer,
                                          bool mandatory, bool dfs) {
  VLOG(3) << "SetBufferLayout : " << buffer << " : "
          << LayoutUtil::HumanString(layout);

  TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
  // Only array-shaped buffers carry layouts; tuples etc. cannot be
  // constrained directly.
  if (!buffer.IsArray()) {
    return FailedPrecondition(
        "Layout of buffer %s cannot be constrained because buffer is not "
        "array-shaped, has shape: %s",
        buffer.ToString(), ShapeUtil::HumanString(buffer.shape()));
  }
  TF_RETURN_IF_ERROR(
      LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));

  auto iter = buffer_constraints_.find(&buffer);
  if (iter != buffer_constraints_.end()) {
    const BufferLayoutConstraint& curr_constraint = iter->second;
    if (Layout::Equal().MinorToMajorOnly()(curr_constraint.layout(), layout)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_constraint.mandatory()) {
      if (!mandatory) {
        // Non-mandatory requests silently defer to an existing mandatory
        // constraint.
        VLOG(3) << "Buffer" << buffer
                << " already has a mandatory layout constrain, skipping";
        return Status::OK();
      }
      // Conflicting mandatory constraints cannot be reconciled.
      return FailedPrecondition(
          "Buffer %s already has the layout constraint %s, cannot add "
          "incompatible constraint %s",
          buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()),
          LayoutUtil::HumanString(layout));
    }
    // Existing constraint is non-mandatory: overwrite it in place.
    iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
  } else {
    // First constraint for this buffer; it must currently be recorded as
    // unconstrained.
    TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
        << buffer.ToString();
    iter = buffer_constraints_
               .insert(std::make_pair(
                   &buffer,
                   BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
               .first;
  }
  // Queue the constraint for propagation.
  added_constraints_.push_back(&iter->second);
  return Status::OK();
}
222
// Adds (or replaces) a layout constraint on operand `operand_no` of
// `instruction`.
//
// Fails if an existing *mandatory* constraint disagrees, or if the operand's
// buffers may be forwarded to the instruction's output (that case would
// affect layouts beyond this immediate use and is not handled). The new
// constraint is appended to the propagation worklist and bubbled forward so
// that the worklist stays sorted by descending priority.
Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
                                           const HloInstruction* instruction,
                                           int64_t operand_no, bool mandatory,
                                           bool dfs) {
  auto priority = LayoutConstraint::kDefaultPriority;
  // The second and third operands (operand_no > 0) of a dynamic-update-slice
  // operation typically have much smaller sizes than the first (operand_no==0)
  // operand. It is necessary to downgrade the importance of the smaller
  // operands, so that the overall layout choice of the operation is dictated
  // by operand 0 when possible.
  if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice &&
      operand_no > 0 && !mandatory) {
    dfs = false;
    priority--;
  }
  VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
          << operand_no << " : "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout)
          << " : priority = " << priority << "\n";
  const OperandLayoutConstraint* curr_shape_layout =
      GetOperandLayoutConstraint(instruction, operand_no);
  if (curr_shape_layout != nullptr) {
    if (curr_shape_layout->shape_layout().MatchesLayoutInShape(
            shape_with_layout, /*minor_to_major_only=*/true)) {
      // New constraint matches existing constraint. Nothing to do.
      return Status::OK();
    }
    if (curr_shape_layout->mandatory()) {
      return FailedPrecondition(
          "Operand %d of instruction %s already has a layout constraint "
          "%s, cannot add incompatible constraint %s",
          operand_no, instruction->name(),
          curr_shape_layout->shape_layout().ToString(),
          ShapeUtil::HumanStringWithLayout(shape_with_layout));
    }
  }

  // If any buffers in the operand occur in the output of the instruction, then
  // return an error. This case is not handled because such a constraint changes
  // layouts beyond this immediate use and is complicated to handle.
  if (AnyOperandBufferForwarded(instruction, operand_no)) {
    return FailedPrecondition(
        "Cannot constraint layout of operand %d of instruction %s "
        "because instruction forwards operand's LogicalBuffer(s)",
        operand_no, instruction->name());
  }

  // Insert a fresh constraint, or overwrite the existing (non-mandatory) one.
  auto key = std::make_pair(instruction, operand_no);
  auto iter = operand_constraints_.find(key);
  if (iter == operand_constraints_.end()) {
    auto pair = std::make_pair(
        key,
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs, priority));
    iter = operand_constraints_.insert(pair).first;
  } else {
    iter->second =
        OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
                                operand_no, mandatory, dfs, priority);
  }
  added_constraints_.push_back(&iter->second);
  // Move the new constraint from the end of the worklist forward if its
  // priority is higher than those ahead of it in the worklist. Assuming the
  // worklist is already sorted, this works like a bubble sort, where the lower
  // priority constraints are always pushed to the tail of the worklist.
  for (int64_t i = added_constraints_.size() - 2; i >= 0; --i) {
    if (added_constraints_[i]->priority() <
        added_constraints_[i + 1]->priority()) {
      std::swap(added_constraints_[i + 1], added_constraints_[i]);
    } else {
      break;
    }
  }
  return Status::OK();
}
298
SetArrayOperandLayout(const Layout & layout,const HloInstruction * instruction,int64_t operand_no,bool mandatory,bool dfs)299 Status LayoutConstraints::SetArrayOperandLayout(
300 const Layout& layout, const HloInstruction* instruction, int64_t operand_no,
301 bool mandatory, bool dfs) {
302 const HloInstruction* operand = instruction->operand(operand_no);
303 TF_RET_CHECK(operand->shape().IsArray());
304 Shape shape(operand->shape());
305 *shape.mutable_layout() = layout;
306 TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
307 return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs);
308 }
309
SetResultLayout(const Shape & shape_with_layout,bool dfs)310 Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout,
311 bool dfs) {
312 VLOG(3) << "SetResultLayout : "
313 << ShapeUtil::HumanStringWithLayout(shape_with_layout);
314
315 const ShapeLayout* curr_shape_layout = ResultLayout();
316 if (curr_shape_layout != nullptr) {
317 if (!curr_shape_layout->MatchesLayoutInShape(
318 shape_with_layout, /*minor_to_major_only=*/true)) {
319 return FailedPrecondition(
320 "Result of computation %s already has the layout constraint %s, "
321 "cannot add incompatible constraint %s",
322 computation_->name(), curr_shape_layout->ToString(),
323 ShapeUtil::HumanStringWithLayout(shape_with_layout));
324 }
325 // New constraint matches existing constraint. Nothing to do.
326 return Status::OK();
327 }
328 result_constraint_.reset(
329 new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs));
330 added_constraints_.push_back(result_constraint_.get());
331
332 return Status::OK();
333 }
334
// Constrains the full output of `instruction` to `shape_with_layout` by
// creating one BufferLayoutConstraint per array-shaped leaf of the output.
//
// The given shape must be compatible with the instruction's shape. Unless
// `allow_alias` is true, each output buffer must be defined by `instruction`
// itself (i.e. not forwarded from an operand).
Status LayoutConstraints::SetInstructionLayout(
    const Shape& shape_with_layout, const HloInstruction* instruction,
    bool mandatory, bool dfs, bool allow_alias) {
  VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
          << ShapeUtil::HumanStringWithLayout(shape_with_layout);

  if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
    return FailedPrecondition(
        "Instruction %s of shape %s cannot be assigned incompatible layout %s",
        instruction->name(), ShapeUtil::HumanString(instruction->shape()),
        ShapeUtil::HumanStringWithLayout(shape_with_layout));
  }

  // Create a BufferLayoutConstraint for each array shape in the output of the
  // instruction.
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape_with_layout,
      [this, instruction, mandatory, allow_alias](
          const Shape& subshape, const ShapeIndex& index) -> Status {
        // Each leaf of the output is expected to have exactly one defining
        // buffer at this point.
        auto buffers =
            points_to_analysis_.GetPointsToSet(instruction).element(index);
        CHECK_EQ(1, buffers.size());
        if (!allow_alias) {
          CHECK_EQ(buffers[0]->instruction(), instruction);
        }

        // Only array leaves that actually carry a layout are constrained;
        // tuple nodes and layout-less leaves are skipped.
        if (subshape.IsArray() && subshape.has_layout()) {
          return SetBufferLayout(subshape.layout(), *buffers[0], mandatory);
        } else {
          return Status::OK();
        }
      });
}
368
BufferLayout(const LogicalBuffer & buffer) const369 const Layout* LayoutConstraints::BufferLayout(
370 const LogicalBuffer& buffer) const {
371 if (const auto* constraint = GetBufferLayoutConstraint(buffer)) {
372 return &constraint->layout();
373 }
374 return nullptr;
375 }
376
GetBufferLayoutConstraint(const LogicalBuffer & buffer) const377 const BufferLayoutConstraint* LayoutConstraints::GetBufferLayoutConstraint(
378 const LogicalBuffer& buffer) const {
379 auto it = buffer_constraints_.find(&buffer);
380 return it == buffer_constraints_.end() ? nullptr : &it->second;
381 }
382
OperandLayout(const HloInstruction * instruction,int64_t operand_no) const383 const ShapeLayout* LayoutConstraints::OperandLayout(
384 const HloInstruction* instruction, int64_t operand_no) const {
385 if (const auto* constraint =
386 GetOperandLayoutConstraint(instruction, operand_no)) {
387 return &constraint->shape_layout();
388 }
389 return nullptr;
390 }
391
GetOperandLayoutConstraint(const HloInstruction * instruction,int64_t operand_no) const392 const OperandLayoutConstraint* LayoutConstraints::GetOperandLayoutConstraint(
393 const HloInstruction* instruction, int64_t operand_no) const {
394 auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
395 return it == operand_constraints_.end() ? nullptr : &it->second;
396 }
397
ResultLayout() const398 const ShapeLayout* LayoutConstraints::ResultLayout() const {
399 return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
400 }
401
// Produces a multi-line human-readable dump of all constraints: per
// instruction (in post-order) the constrained operands and defined buffers,
// followed by the result constraint if present.
string LayoutConstraints::ToString() const {
  string output;
  absl::StrAppend(&output, "LayoutConstraints for computation ",
                  computation_->name(), ":\n");
  for (auto* instruction : computation_->MakeInstructionPostOrder()) {
    absl::StrAppend(&output, " ", instruction->ToShortString(), "\n");
    for (int64_t i = 0; i < instruction->operand_count(); ++i) {
      if (OperandLayout(instruction, i) != nullptr) {
        absl::StrAppend(&output, " operand (", i,
                        "): ", OperandLayout(instruction, i)->ToString(), "\n");
      }
    }
    for (const LogicalBuffer* buffer :
         points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
      if (BufferLayout(*buffer) != nullptr) {
        absl::StrAppend(&output, " ", buffer->ToString(), " : ",
                        LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
      }
    }
  }

  if (ResultLayout() != nullptr) {
    absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n");
  }
  return output;
}
428
429 namespace {
430
IsHostSendRecv(const HloInstruction * instruction)431 bool IsHostSendRecv(const HloInstruction* instruction) {
432 const HloSendRecvInstruction* send_recv_instr =
433 DynCast<HloSendRecvInstruction>(instruction);
434 return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
435 }
436
437 } // namespace
438
// Records layout constraints for all host-transfer channels used in
// `computation` into host_channel_constraints_.
//
// Only host Send/Recv instructions are considered; the layout of a channel
// is taken from the data element (tuple element 0) of the Send/Recv shape.
// Returns an error if the same channel was already constrained.
Status LayoutAssignment::BuildHostChannelConstraints(
    HloComputation* computation) {
  for (auto* instruction : computation->instructions()) {
    const HloSendRecvInstruction* send_recv_instr =
        DynCast<HloSendRecvInstruction>(instruction);
    if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
      continue;
    }

    // For host transfers the Send and Recv instruction carry the layout.
    if (instruction->opcode() == HloOpcode::kSend ||
        instruction->opcode() == HloOpcode::kRecv) {
      // The transferred payload is tuple element 0 of the Send/Recv shape;
      // it must be an array with a concrete layout.
      const Shape& data_shape =
          ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
      TF_RET_CHECK(data_shape.IsArray());
      TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
      // ConstrainChannel returns the previously registered layout, if any;
      // a non-null result means the channel was constrained twice.
      const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
          *send_recv_instr->channel_id(), data_shape.layout());
      TF_RET_CHECK(prev_layout == nullptr)
          << "Cannot constrain host transfer layout as it was set to "
          << LayoutUtil::HumanString(*prev_layout) << ": "
          << send_recv_instr->ToString();
    }
  }
  return Status::OK();
}
465
466 namespace {
467
IsLayoutConstrainedCustomCall(HloInstruction * instruction)468 bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
469 const HloCustomCallInstruction* custom_call =
470 DynCast<HloCustomCallInstruction>(instruction);
471 return custom_call != nullptr && custom_call->layout_constrained();
472 }
473
IsLayoutConstrainedCollective(const HloInstruction * instruction)474 bool IsLayoutConstrainedCollective(const HloInstruction* instruction) {
475 const HloCollectiveInstruction* collective =
476 DynCast<HloCollectiveInstruction>(instruction);
477 return collective != nullptr && collective->constrain_layout();
478 }
479
// Propagates a parameter's known layout (`shape`) forward to its users as
// non-mandatory constraints.
//
// GetTupleElement users receive an instruction-layout constraint on the
// selected tuple element and are recursed into; all other (non-tuple) users
// receive an operand-layout constraint for the operand position that
// consumes `instruction`.
Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
                                       const Shape& shape,
                                       LayoutConstraints* constraints) {
  for (auto* user : instruction->users()) {
    // Excluding tuple operations as they do not participate in layout
    // propagations (they do not create or alias buffers).
    if (user->opcode() == HloOpcode::kTuple) {
      continue;
    }
    VLOG(3) << "Setting user layout : " << user->ToString();
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // Narrow the layout to the selected tuple element and keep
      // propagating through the GTE.
      auto tuple_index = user->tuple_index();
      CHECK(shape.IsTuple());
      auto elem_shape = shape.tuple_shapes(tuple_index);
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          elem_shape, user, /*mandatory=*/false, /*dfs=*/false,
          /*allow_alias=*/true));
      TF_RETURN_IF_ERROR(
          PropagateParameterLayoutToUsers(user, elem_shape, constraints));
    } else {
      // Constrain the operand slot through which `user` reads the
      // parameter value.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          shape, user, user->operand_index(instruction), /*mandatory=*/false,
          /*dfs=*/false));
    }
  }
  return Status::OK();
}
507
508 } // namespace
509
// Adds all layout constraints that are fixed by the environment rather than
// chosen by the assignment algorithm: infeed/outfeed shapes, parameter
// layouts from the ComputationLayout, layout-constrained custom calls and
// collectives, channel (Send/Recv, cross-module all-reduce) layouts, and the
// layouts imposed by already-assigned called computations (kCall, kWhile,
// kConditional). Finally constrains the computation result, if known.
Status LayoutAssignment::AddMandatoryConstraints(
    const ComputationLayout* computation_layout,
    ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
    LayoutConstraints* constraints) {
  VLOG(2) << "Adding mandatory layout constraints to computation "
          << computation->name();

  // Host Send/Recv channels are tracked separately from cross-module
  // channels.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };

  // Constrain layouts of instructions which define values with pre-existing
  // layouts.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kInfeed) {
      // Infeed layouts must match the layout of the original inserted
      // instruction.
      // TODO(b/31425034): Change infeeds to be more like parameters, with
      // shapes in the ComputationLayout.
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kOutfeed) {
      // Constrain the input to the Outfeed instruction to be the expected
      // layout of the Outfeed.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          instruction->outfeed_shape(), instruction, 0));
    } else if (instruction->opcode() == HloOpcode::kParameter) {
      if (computation_layout != nullptr) {
        const ShapeLayout& parameter_layout =
            computation_layout->parameter_layout(
                instruction->parameter_number());
        // Parameter layouts must match the respective layout in
        // ComputationLayout, if there is one.
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            parameter_layout.shape(), instruction));
        if (reverse_computation_order_) {
          // When processing computations in reverse order, push parameter
          // layouts forward to their users eagerly.
          TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers(
              instruction, parameter_layout.shape(), constraints));
        }
      }
    } else if (IsLayoutConstrainedCustomCall(instruction)) {
      const HloCustomCallInstruction* custom_call =
          DynCast<HloCustomCallInstruction>(instruction);

      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          custom_call->shape(), custom_call, /*mandatory=*/true, /*dfs=*/true,
          /*allow_alias=*/true));

      for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
        if (constraints->AnyOperandBufferForwarded(custom_call, i)) {
          // A forwarded operand already shares its buffers with the output;
          // only full forwarding is supported.
          TF_RET_CHECK(constraints->AllOperandBuffersForwarded(custom_call, i))
              << "Partial alias of an operand is not supported";
        } else {
          TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
              custom_call->operand_shapes_with_layout()[i], custom_call, i));
        }
      }
    } else if (instruction->opcode() == HloOpcode::kSend ||
               instruction->opcode() == HloOpcode::kRecv) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64_t channel_id = *instruction->channel_id();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      if (instruction->opcode() == HloOpcode::kSend) {
        // TODO(b/68493863): Change to use SetOperandLayout().
        const Shape send_buffer_shape = instruction->operand(0)->shape();
        TF_RET_CHECK(send_buffer_shape.IsArray());
        Shape new_buffer_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(send_buffer_shape,
                                        *instruction->channel_id());
        TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
            new_buffer_shape, instruction->operand(0)));
      } else {
        // Recv: constrain the data buffer (tuple element {0}) defined by the
        // instruction itself.
        const Shape recv_buffer_shape =
            ShapeUtil::GetTupleElementShape(instruction->shape(), 0);
        TF_RET_CHECK(recv_buffer_shape.IsArray());
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(instruction,
                                                                 {0}));
        Shape new_shape =
            get_channel_constraints(instruction)
                ->LayoutShapeForChannel(recv_buffer_shape,
                                        *instruction->channel_id());
        TF_RETURN_IF_ERROR(
            constraints->SetBufferLayout(new_shape.layout(), *buffer));
      }
    } else if (IsLayoutConstrainedCollective(instruction)) {
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(instruction->shape(), instruction));
    } else if (instruction->IsCrossModuleAllReduce()) {
      CHECK(get_channel_constraints(instruction))
          << "Multi-module layout assignment requires ChannelLayoutConstraints";
      int64_t channel_id = instruction->channel_id().value();
      if (!get_channel_constraints(instruction)
               ->IsChannelConstrained(channel_id)) {
        continue;
      }
      // TODO(b/68493863): Change to use SetOperandLayout().
      const Shape& buffer_shape = instruction->operand(0)->shape();
      TF_RET_CHECK(buffer_shape.IsArray());
      Shape new_buffer_shape =
          get_channel_constraints(instruction)
              ->LayoutShapeForChannel(buffer_shape, channel_id);
      TF_RETURN_IF_ERROR(
          constraints->SetInstructionLayout(new_buffer_shape, instruction));
    }
  }

  // Constrain layouts of instructions which call computations which have
  // already been assigned layouts. Instructions which call computations in a
  // parallel element-wise context (eg, map or reduce) do not need layout
  // constraints because they operate on scalars.
  for (auto* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kCall &&
        computation_layouts_.find(instruction->to_apply()) !=
            computation_layouts_.end()) {
      // kCall instruction operands and output must match the ComputationLayout
      // of the called computation.
      const ComputationLayout& called_computation_layout =
          FindOrDie(computation_layouts_, instruction->to_apply());
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          called_computation_layout.result_layout().shape(), instruction));
      TF_RET_CHECK(instruction->operand_count() ==
                   called_computation_layout.parameter_count());
      for (int64_t i = 0; i < instruction->operand_count(); ++i) {
        TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
            called_computation_layout.parameter_layout(i).shape(), instruction,
            i));
      }
    } else if (instruction->opcode() == HloOpcode::kWhile &&
               computation_layouts_.find(instruction->while_body()) !=
                   computation_layouts_.end()) {
      // Layout of input and output of kWhile instruction must be equal and must
      // match both input and output of body computation. Also, the input of
      // condition computation must match kWhile layout.
      HloComputation* body = instruction->while_body();
      HloComputation* condition = instruction->while_condition();
      const HloInstruction* init = instruction->operand(0);
      ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
      ComputationLayout& condition_layout =
          FindOrDie(computation_layouts_, condition);

      // Check a few invariants irrespective of layout.
      CHECK_EQ(1, instruction->operand_count());
      CHECK_EQ(1, body->num_parameters());
      CHECK_EQ(1, condition->num_parameters());
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   body_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
                                   condition_layout.parameter_shape(0)));
      DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));

      // Force the loop-carried value to have one consistent layout: body
      // parameter follows body result, condition parameter follows body
      // parameter.
      if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
                << " while=" << instruction->name()
                << " shape=" << body_layout.result_layout().ToString();
        *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
      }
      if (condition_layout.parameter_layout(0) !=
          body_layout.parameter_layout(0)) {
        VLOG(2) << "Reset %while condition parameter layout: cond="
                << condition->name() << " while=" << instruction->name()
                << " shape=" << body_layout.parameter_layout(0).ToString();
        *condition_layout.mutable_parameter_layout(0) =
            body_layout.parameter_layout(0);
      }

      // Constrain the output and the operand of the while instruction to match
      // the computations.
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          body_layout.result_shape(), instruction, 0));
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          body_layout.result_shape(), instruction));
    } else if (instruction->opcode() == HloOpcode::kConditional &&
               computation_layouts_.find(instruction->branch_computation(0)) !=
                   computation_layouts_.end()) {
      // Find the conditional branch with the most instructions and force all
      // other computations to match that layout. A potentially better decision
      // could count the number FLOPs or how constrained the layouts are.
      int64_t largest_branch = -1;
      int64_t largest_instruction_count = 0;
      for (int j = 0; j < instruction->branch_count(); ++j) {
        const int64_t instruction_count =
            instruction->branch_computation(j)->instruction_count();
        if (instruction_count > largest_instruction_count &&
            !ShapeUtil::IsEmptyTuple(instruction->operand(j + 1)->shape())) {
          largest_branch = j;
          largest_instruction_count = instruction_count;
        }
      }
      if (largest_branch == -1) {
        largest_branch = 0;
      }
      ComputationLayout& best_branch_computation_layout =
          FindOrDie(computation_layouts_,
                    instruction->branch_computation(largest_branch));
      for (int k = 0; k < instruction->branch_count(); ++k) {
        // Visit the best branch first.
        int j = (k + largest_branch) % instruction->branch_count();
        // NOTE(review): `j` (the rotated index) is only used for the
        // parameter-count check below, while the layout lookup and the
        // operand constraint use the raw index `k`. This mixing looks
        // suspicious — confirm against upstream whether `k` is intended in
        // all three places.
        TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1);
        ComputationLayout& branch_computation_layout =
            FindOrDie(computation_layouts_, instruction->branch_computation(k));
        if (!branch_computation_layout.result_layout().MatchesLayoutInShape(
                best_branch_computation_layout.result_layout().shape(),
                /*minor_to_major_only=*/true)) {
          // Branch disagrees with the chosen result layout: drop its layout
          // and remember the mismatch so it is re-assigned later.
          computation_layouts_.erase(instruction->branch_computation(k));
          InsertOrDie(&conditional_mismatch_,
                      instruction->branch_computation(k),
                      best_branch_computation_layout);
        } else {
          TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
              branch_computation_layout.parameter_shape(0), instruction, k + 1,
              /*mandatory=*/true));
        }
      }
      TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
          best_branch_computation_layout.parameter_shape(0), instruction,
          largest_branch + 1,
          /*mandatory=*/true));
      TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
          best_branch_computation_layout.result_shape(), instruction));
    }
  }
  // Finally set the result layout to match ComputationLayout, if there is one.
  if (conditional_mismatch_.count(computation) > 0) {
    TF_RETURN_IF_ERROR(constraints->SetResultLayout(
        FindOrDie(conditional_mismatch_, computation).result_layout().shape()));
  } else if (computation_layout != nullptr) {
    const ShapeLayout& result_layout = computation_layout->result_layout();
    if (result_layout.LayoutIsSet()) {
      TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
    }
  }
  return Status::OK();
}
751
752 namespace {
753
LayoutsInShapesEqual(const Shape & lhs,const Shape & rhs)754 bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) {
755 return Layout::Equal().MinorToMajorOnly()(lhs.layout(), rhs.layout());
756 }
757
758 // The operands of a call must match the layouts of parameters in the
759 // ComputationLayout, and the call instruction itself must match the result
760 // layout in the ComputationLayout.
CheckCallLayout(HloInstruction * call,const ComputationLayout & computation_layout)761 Status CheckCallLayout(HloInstruction* call,
762 const ComputationLayout& computation_layout) {
763 HloComputation* computation = call->to_apply();
764 TF_RET_CHECK(computation->num_parameters() == call->operand_count());
765 for (int64_t i = 0; i < computation->num_parameters(); ++i) {
766 TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
767 call->operand(i)->shape(), /*minor_to_major_only=*/true));
768 }
769 TF_RET_CHECK(computation_layout.result_layout().MatchesLayoutInShape(
770 call->shape(), /*minor_to_major_only=*/true));
771 return Status::OK();
772 }
773
774 // Operands of layout-constrained custom calls must match the expected
775 // constrained layouts.
CheckCustomCallLayout(HloInstruction * instruction)776 Status CheckCustomCallLayout(HloInstruction* instruction) {
777 if (IsLayoutConstrainedCustomCall(instruction)) {
778 const HloCustomCallInstruction* custom_call =
779 DynCast<HloCustomCallInstruction>(instruction);
780 for (int64_t i = 0; i < custom_call->operand_count(); ++i) {
781 TF_RET_CHECK(
782 LayoutsInShapesEqual(custom_call->operand(i)->shape(),
783 custom_call->operand_shapes_with_layout()[i]));
784 }
785 }
786 return Status::OK();
787 }
788
789 // For a while instruction, all the following layouts must be the same:
790 // (1) init operand
791 // (2) condition computation parameter
792 // (3) body computation parameter
793 // (4) body computation result
794 // (5) while instruction result
CheckWhileLayout(HloInstruction * while_inst,const ComputationLayout & condition_computation_layout,const ComputationLayout & body_computation_layout)795 Status CheckWhileLayout(HloInstruction* while_inst,
796 const ComputationLayout& condition_computation_layout,
797 const ComputationLayout& body_computation_layout) {
798 auto init_shape = while_inst->operand(0)->shape();
799 TF_RET_CHECK(
800 condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
801 init_shape, /*minor_to_major_only=*/true));
802 TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
803 init_shape, /*minor_to_major_only=*/true));
804 TF_RET_CHECK(body_computation_layout.result_layout().MatchesLayoutInShape(
805 init_shape, /*minor_to_major_only=*/true));
806 TF_RET_CHECK(LayoutsInShapesEqual(init_shape, while_inst->shape()));
807 return Status::OK();
808 }
809
CheckConditionalLayout(HloInstruction * instruction,absl::Span<const ComputationLayout> branch_computation_layouts)810 Status CheckConditionalLayout(
811 HloInstruction* instruction,
812 absl::Span<const ComputationLayout> branch_computation_layouts) {
813 for (int j = 0; j < instruction->branch_count(); ++j) {
814 const HloInstruction* branch_operand = instruction->operand(j + 1);
815 TF_RET_CHECK(
816 branch_computation_layouts[0].result_layout().MatchesLayoutInShape(
817 branch_computation_layouts[j].result_layout().shape(),
818 /*minor_to_major_only=*/true));
819 TF_RET_CHECK(
820 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
821 instruction->shape(), /*minor_to_major_only=*/true));
822 TF_RET_CHECK(
823 branch_computation_layouts[j].result_layout().MatchesLayoutInShape(
824 instruction->branch_computation(j)->root_instruction()->shape(),
825 /*minor_to_major_only=*/true));
826 TF_RET_CHECK(
827 branch_computation_layouts[j].parameter_layout(0).MatchesLayoutInShape(
828 branch_operand->shape(), /*minor_to_major_only=*/true));
829 }
830 return Status::OK();
831 }
832
833 // Fusion parameters must match the layout of the fusion instructions operands,
834 // and the root of the fusion expression must match the layout of the fusion
835 // instruction.
CheckFusionLayout(HloInstruction * fusion)836 Status CheckFusionLayout(HloInstruction* fusion) {
837 TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
838
839 TF_RET_CHECK(LayoutsInShapesEqual(fusion->shape(),
840 fusion->fused_expression_root()->shape()));
841 for (int64_t i = 0; i < fusion->operand_count(); ++i) {
842 TF_RET_CHECK(LayoutsInShapesEqual(fusion->fused_parameter(i)->shape(),
843 fusion->operand(i)->shape()));
844 }
845 return Status::OK();
846 }
847
848 // The layout of a parameter must match the respective layout in the
849 // computation's ComputationLayout.
CheckParameterLayout(HloInstruction * parameter,const ComputationLayout & computation_layout)850 Status CheckParameterLayout(HloInstruction* parameter,
851 const ComputationLayout& computation_layout) {
852 const ShapeLayout& parameter_layout =
853 computation_layout.parameter_layout(parameter->parameter_number());
854 return ShapeUtil::ForEachSubshapeWithStatus(
855 parameter_layout.shape(),
856 [&](const Shape& subshape, const ShapeIndex& shape_index) {
857 if (!ShapeUtil::IsLeafIndex(parameter_layout.shape(), shape_index) ||
858 !subshape.has_layout()) {
859 return Status::OK();
860 }
861 if (!Shape::Equal().MinorToMajorOnlyInLayout().IgnoreDynamicDimension()(
862 subshape,
863 ShapeUtil::GetSubshape(parameter->shape(), shape_index))) {
864 return InternalError(
865 "parameter instruction %s does not match layout of computation "
866 "shape: %s",
867 parameter->ToString(), parameter_layout.ToString());
868 }
869 return Status::OK();
870 });
871 }
872
873 // The layout of a constant instruction must match the layout of its literal.
CheckConstantLayout(HloInstruction * constant)874 Status CheckConstantLayout(HloInstruction* constant) {
875 if (!LayoutsInShapesEqual(constant->literal().shape(), constant->shape())) {
876 return InternalError(
877 "constant instruction %s does not match the layout of its literal %s",
878 constant->ToString(),
879 ShapeUtil::HumanStringWithLayout(constant->literal().shape()));
880 }
881 return Status::OK();
882 }
883
884 } // namespace
885
// Creates a copy of `instruction` laid out as `shape_with_layout`, which must
// be compatible with the instruction's shape. Array-shaped instructions get a
// single kCopy; tuple-shaped instructions are deep-copied element by element
// (elements whose layout already matches are passed through) and reassembled
// with a new Tuple instruction. Returns the copy, or a FailedPrecondition
// error for shapes that are neither array nor tuple.
StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
    const Shape& shape_with_layout, HloInstruction* instruction) {
  TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
  DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()))
      << ShapeUtil::HumanString(shape_with_layout) << " "
      << ShapeUtil::HumanString(instruction->shape())
      << " instruction: " << instruction->ToString();

  if (instruction->shape().IsTuple()) {
    // Copy tuple elements which have differing layouts.
    std::vector<HloInstruction*> element_copies;
    for (int64_t i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
         ++i) {
      const Shape& target_shape =
          ShapeUtil::GetSubshape(shape_with_layout, {i});
      const Shape& instr_shape =
          ShapeUtil::GetSubshape(instruction->shape(), {i});
      // Extract the i-th element so it can be copied (or reused) on its own.
      HloInstruction* gte = instruction->parent()->AddInstruction(
          HloInstruction::CreateGetTupleElement(instr_shape, instruction, i));

      if (Shape::Equal().MinorToMajorOnlyInLayout()(target_shape,
                                                    instr_shape)) {
        // Shapes and layouts are equal, no need to copy.
        element_copies.push_back(gte);
      } else {
        SetupCopiedInstruction(*instruction, gte, {i});
        // Recurse to copy each element.
        TF_ASSIGN_OR_RETURN(HloInstruction * element_copy,
                            CreateCopyWithNewLayout(target_shape, gte));
        element_copies.push_back(element_copy);
      }
    }
    // Gather element copies into a tuple with a new Tuple instruction.
    HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
        HloInstruction::CreateTuple(element_copies));
    SetupCopiedInstruction(*instruction, tuple_copy, {});
    // Give the new tuple the target layout, copied subshape-by-subshape.
    LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, tuple_copy->mutable_shape()));
    return tuple_copy;
  } else if (instruction->shape().IsArray()) {
    HloInstruction* copy =
        instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
            instruction->shape(), HloOpcode::kCopy, instruction));
    // Record the copy so the pass can track copies it introduced.
    RegisterAddedCopy(copy);
    SetupCopiedInstruction(*instruction, copy, {});
    LayoutUtil::ClearLayout(copy->mutable_shape());
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        shape_with_layout, copy->mutable_shape()));

    return copy;
  } else {
    return FailedPrecondition(
        "Can only copy array and tuple shaped instructions");
  }
}
942
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. Tuple
// operands will be deep-copied.
Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
    const ShapeLayout& operand_layout, HloInstruction* instruction,
    int64_t operand_no) {
  HloInstruction* operand = instruction->mutable_operand(operand_no);
  TF_RET_CHECK(operand_layout.LayoutIsSet());
  TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));

  if (Shape::Equal().MinorToMajorOnlyInLayout()(operand_layout.shape(),
                                                operand->shape())) {
    VLOG(2) << "Operand " << operand->ToString() << " layout matches in "
            << instruction->ToString();
    // Operand layout already matches our constraint. Nothing to do.
    return Status::OK();
  }
  VLOG(2) << "Operand " << operand->ToString() << " layout does not match "
          << operand_layout.ToString() << " in " << instruction->ToString();

  // If the operand is only used by a conditional, do the copy inside the branch
  // to avoid overhead for other branches.
  if (!reverse_computation_order_ &&
      instruction->opcode() == HloOpcode::kConditional && operand_no > 0 &&
      instruction->operand(operand_no)->user_count() == 1) {
    // Operand `operand_no` feeds branch computation `operand_no - 1`
    // (operand 0 of a conditional is the branch selector).
    auto branch_comp = instruction->branch_computation(operand_no - 1);
    auto param = branch_comp->parameter_instruction(0);
    // Give the parameter the operand's current layout, then insert the
    // layout-converting copy inside the branch body.
    *param->mutable_shape() = operand->shape();
    // Snapshot users before creating the copy, since the copy itself becomes
    // a user of the parameter.
    auto param_users = param->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * param_copy,
                        CreateCopyWithNewLayout(operand_layout.shape(), param));
    for (auto user : param_users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy));
    }
    VLOG(2) << "New copy of " << operand->ToString() << " is "
            << param_copy->ToString();
    if (param == branch_comp->root_instruction()) {
      branch_comp->set_root_instruction(param_copy,
                                        /*accept_different_shape=*/true);
    }
    // Record the parameter's new layout in the branch's ComputationLayout.
    *FindOrDie(computation_layouts_, branch_comp).mutable_parameter_layout(0) =
        ShapeLayout(operand->shape());
    return Status::OK();
  }

  // General case: copy the operand at the use site.
  TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
                      CreateCopyWithNewLayout(operand_layout.shape(), operand));

  VLOG(4) << "New copy of " << operand->ToString() << " is "
          << operand_copy->ToString();
  return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
995
SetupCopiedInstruction(const HloInstruction & instruction,HloInstruction * copy,const ShapeIndex & index)996 void LayoutAssignment::SetupCopiedInstruction(const HloInstruction& instruction,
997 HloInstruction* copy,
998 const ShapeIndex& index) {
999 if (instruction.has_sharding()) {
1000 // If the index is empty, we want to copy the whole sharding, in case the
1001 // sharding is a tuple sharding.
1002 HloSharding sharding =
1003 !index.empty() && instruction.sharding().IsTuple()
1004 ? instruction.sharding().GetSubSharding(instruction.shape(), index)
1005 : instruction.sharding();
1006 // We propagate the sharding to the copied instruction only if it is a
1007 // special sharding, like tiled ones.
1008 // Otherwise it is preferable to leave the new instruction without device,
1009 // and let the automatic device placer to choose the best location.
1010 auto device = sharding.UniqueDevice();
1011 if (!device || HloSharding::IsReservedDevice(*device)) {
1012 copy->set_sharding(sharding);
1013 }
1014 }
1015 copy->set_metadata(instruction.metadata());
1016 }
1017
// Verifies that every instruction in `module` has a valid, fully-assigned
// layout, that each leaf of an instruction's output matches the layout of the
// logical buffers that may define it, and that instructions with special
// layout constraints (call, custom-call, fusion, parameter, constant, while,
// conditional) satisfy them. Finally checks the entry computation's result
// layout, if one is set.
Status LayoutAssignment::CheckLayouts(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto points_to_analysis,
                      TuplePointsToAnalysis::Run(module));
  for (auto* computation : module->MakeNonfusionComputations()) {
    for (auto* instruction : computation->instructions()) {
      // Verify every instruction has a layout and the layout is valid for the
      // shape.
      TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
      TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));

      // Use points-to analysis to verify that every subshape element in the
      // output of the instruction matches the layout of the logical buffer
      // which could be the source of the subshape value.
      const PointsToSet& points_to_set =
          points_to_analysis->GetPointsToSet(instruction);
      TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus(
          [&instruction](ShapeIndex index,
                         const PointsToSet::BufferList& buffers) -> Status {
            if (ShapeUtil::IsLeafIndex(instruction->shape(), index)) {
              const Shape& instruction_subshape =
                  ShapeUtil::GetSubshape(instruction->shape(), index);
              // Every buffer that may flow into this leaf must agree on
              // minor-to-major order (dynamic dimensions are ignored).
              for (const LogicalBuffer* buffer : buffers) {
                if (!Shape::Equal()
                         .IgnoreDynamicDimension()
                         .MinorToMajorOnlyInLayout()(instruction_subshape,
                                                     buffer->shape())) {
                  return InternalError(
                      "Layout of instruction %s at index {%s} does not match "
                      "source LogicalBuffer %s: %s vs %s",
                      instruction->name(), absl::StrJoin(index, ","),
                      buffer->ToString(),
                      ShapeUtil::HumanStringWithLayout(instruction_subshape),
                      ShapeUtil::HumanStringWithLayout(buffer->shape()));
                }
              }
            }
            return Status::OK();
          }));

      // Verify instructions that have special layout constraints.
      switch (instruction->opcode()) {
        case HloOpcode::kCall:
          TF_RETURN_IF_ERROR(CheckCallLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->to_apply())));
          break;
        case HloOpcode::kCustomCall:
          TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
          break;
        case HloOpcode::kFusion:
          TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
          break;
        case HloOpcode::kParameter:
          TF_RETURN_IF_ERROR(CheckParameterLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->parent())));
          break;
        case HloOpcode::kConstant:
          TF_RETURN_IF_ERROR(CheckConstantLayout(instruction));
          break;
        case HloOpcode::kWhile:
          TF_RETURN_IF_ERROR(CheckWhileLayout(
              instruction,
              FindOrDie(computation_layouts_, instruction->while_condition()),
              FindOrDie(computation_layouts_, instruction->while_body())));
          break;
        case HloOpcode::kConditional: {
          // Collect the stored layout of every branch computation before
          // checking them jointly against the conditional.
          std::vector<ComputationLayout> branch_computation_layouts;
          for (auto branch_computation : instruction->branch_computations()) {
            branch_computation_layouts.emplace_back(
                FindOrDie(computation_layouts_, branch_computation));
          }
          TF_RETURN_IF_ERROR(CheckConditionalLayout(
              instruction, absl::MakeSpan(branch_computation_layouts)));
          break;
        }
        default:
          break;
      }
    }
  }
  // Finally verify the result layout, if set, matches the layout of the entry
  // computation root.
  const ShapeLayout& result_layout =
      FindOrDie(computation_layouts_, module->entry_computation())
          .result_layout();
  if (result_layout.LayoutIsSet()) {
    TF_RET_CHECK(
        Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
            module->result_shape(), result_layout.shape()));
  }
  return Status::OK();
}
1111
// Constructs the pass. `entry_computation_layout` constrains the entry
// computation's parameter/result layouts; a copy is saved so the original can
// be restored if the pass must undo its side effects. `channel_constraints`
// may be null; when provided, a snapshot is kept for the same reason.
// `reverse_computation_order` is stored and consulted later (e.g. when
// deciding whether to sink copies into conditional branches).
LayoutAssignment::LayoutAssignment(
    ComputationLayout* entry_computation_layout,
    ChannelLayoutConstraints* channel_constraints,
    bool reverse_computation_order)
    : entry_computation_layout_(entry_computation_layout),
      saved_entry_computation_layout_(*entry_computation_layout),
      reverse_computation_order_(reverse_computation_order),
      channel_layout_constraints_(channel_constraints) {
  if (channel_layout_constraints_ != nullptr) {
    // Save a copy of the input ChannelLayoutConstraints so that we can reset it
    // if we have to undo previous operations (ClearPreviousPassSideEffects()).
    channel_constraints_ = *channel_layout_constraints_;
  }
  VLOG(1) << "Entry computation layout given to layout assignment: "
          << entry_computation_layout_->ToString();
}
1128
ChooseOperandLayoutFromOutputLayout(const Layout & output_layout,const HloInstruction * instruction,int64_t operand_no)1129 std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
1130 const Layout& output_layout, const HloInstruction* instruction,
1131 int64_t operand_no) {
1132 const HloInstruction* operand = instruction->operand(operand_no);
1133 CHECK(instruction->shape().IsArray());
1134 CHECK(operand->shape().IsArray());
1135 if (!ShapeUtil::IsScalar(operand->shape()) &&
1136 operand->shape().rank() == instruction->shape().rank() &&
1137 !InstructionCanChangeLayoutInstance(instruction)) {
1138 // Propagate the result layout to the operand layout if the instruction
1139 // requires the same layout out for the result and the operand.
1140 //
1141 // For elementwise operations, using the same layout for the operands and
1142 // the result also has the following benefits:
1143 // 1) the elementwise operation can reuse its operand's buffer, and
1144 // 2) the input and output elements can reuse the same linear index.
1145 return absl::make_unique<Layout>(output_layout);
1146 }
1147
1148 if (instruction->opcode() == HloOpcode::kReshape) {
1149 // Prefer the operand layout that makes the reshape an bitcast. If any
1150 // dimension bound is 1 in the operand shape, there may be several such
1151 // layouts. So if 'output_layout' is the default layout, try if the
1152 // reshape is a bitcast when using the same layout. This may avoid copy
1153 // operations. For similar reasons, if the operand and output have the same
1154 // rank, try to match the operand's layout to the output.
1155 if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
1156 ShapeUtil::TrueRank(instruction->shape()) == 1) {
1157 // Don't assign a layout in case of R1 -> effective R1 reshape.
1158 return nullptr;
1159 }
1160
1161 const Shape& output_shape = instruction->shape();
1162 Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
1163 output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
1164 LayoutUtil::MinorToMajor(output_layout));
1165 Shape operand_shape = operand->shape();
1166 *operand_shape.mutable_layout() =
1167 LayoutUtil::GetDefaultLayoutForShape(operand_shape);
1168 auto aligned_operand_shape =
1169 ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape);
1170 if (aligned_operand_shape) {
1171 auto operand_layout = aligned_operand_shape.value().layout();
1172 TF_CHECK_OK(
1173 LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape));
1174 return absl::make_unique<Layout>(operand_layout);
1175 }
1176 }
1177
1178 if (instruction->opcode() == HloOpcode::kTranspose) {
1179 // Pick the operand layout that makes the transpose a bitcast.
1180 int64_t rank = instruction->shape().rank();
1181 std::vector<int64> new_minor_to_major(rank);
1182 for (int64_t i = 0; i < rank; ++i) {
1183 int64_t output_dim = LayoutUtil::Minor(output_layout, i);
1184 int64_t operand_dim = instruction->dimensions(output_dim);
1185 new_minor_to_major[i] = operand_dim;
1186 }
1187 Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
1188 TF_CHECK_OK(
1189 LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
1190 return absl::make_unique<Layout>(operand_layout);
1191 }
1192
1193 return nullptr;
1194 }
1195
ChooseOutputLayoutFromOperandLayout(const Layout & operand_layout,const HloInstruction * user,int64_t operand_no)1196 std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
1197 const Layout& operand_layout, const HloInstruction* user,
1198 int64_t operand_no) {
1199 const HloInstruction* operand = user->operand(operand_no);
1200
1201 CHECK(user->shape().IsArray() && operand->shape().IsArray());
1202
1203 if (!ShapeUtil::IsScalar(operand->shape()) &&
1204 operand->shape().rank() == user->shape().rank() &&
1205 !InstructionCanChangeLayoutInstance(user)) {
1206 // Assign users the same layout as the operand.
1207 return absl::make_unique<Layout>(operand_layout);
1208 }
1209
1210 if (user->opcode() == HloOpcode::kReshape) {
1211 // Prefer the user layout that makes the reshape an bitcast. If any
1212 // dimension bound is 1 in the user shape, there may be several such
1213 // layouts. So if 'operand_layout' is the default layout, try if the
1214 // reshape is a bitcast when using the same layout. This may avoid copy
1215 // operations. For similar reasons, if the operand and output have the same
1216 // rank, try to match the outputs's layout to the operand.
1217 if (ShapeUtil::TrueRank(operand->shape()) == 1 &&
1218 ShapeUtil::TrueRank(user->shape()) == 1) {
1219 // Don't assign a layout in case of R1 -> effective R1 reshape.
1220 return nullptr;
1221 }
1222 Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
1223 operand->shape().element_type(),
1224 AsInt64Slice(operand->shape().dimensions()),
1225 LayoutUtil::MinorToMajor(operand_layout));
1226 Shape output_shape = user->shape();
1227 *output_shape.mutable_layout() =
1228 LayoutUtil::GetDefaultLayoutForShape(output_shape);
1229 auto aligned_user_shape =
1230 ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape);
1231 if (aligned_user_shape) {
1232 auto user_layout = aligned_user_shape.value().layout();
1233 TF_CHECK_OK(
1234 LayoutUtil::ValidateLayoutForShape(user_layout, output_shape));
1235 return absl::make_unique<Layout>(user_layout);
1236 }
1237 }
1238
1239 if (user->opcode() == HloOpcode::kTranspose) {
1240 // Pick the user layout that makes the transpose a bitcast.
1241 int64_t rank = user->shape().rank();
1242 std::vector<int64> new_minor_to_major(rank);
1243 auto inverse_dimensions = InversePermutation(user->dimensions());
1244 for (int64_t i = 0; i < rank; ++i) {
1245 int64_t operand_dim = LayoutUtil::Minor(operand_layout, i);
1246 int64_t user_dim = inverse_dimensions[operand_dim];
1247 new_minor_to_major[i] = user_dim;
1248 }
1249 Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
1250 TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
1251 return absl::make_unique<Layout>(user_layout);
1252 }
1253
1254 return nullptr;
1255 }
1256
// Propagates all pending layout constraints through the graph until the
// worklist is exhausted, dispatching each constraint to the propagation
// routine matching its concrete type.
Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
  // Gathers all initial constraints in a worklist and propagates them in
  // depth-first order. DFS order seems to be better than BFS because a
  // constraint is propagated as far as possible before propagating unrelated
  // constraints which makes it less likely that conflicting constraints will be
  // propagated to instructions. However, we should experiment with other orders
  // too.
  std::deque<const LayoutConstraint*> worklist;

  // Lambda for moving newly added constraints to the worklist.
  auto add_new_constraints_to_worklist = [constraints, &worklist]() {
    // Add constraints to the front of the deque for DFS ordering.
    for (auto* constraint : constraints->ConsumeAddedConstraints()) {
      if (constraint->dfs()) {
        worklist.push_front(constraint);
      } else {
        // Constraints that opt out of DFS are appended (BFS-like handling).
        VLOG(3) << "push back constraint for propagation : "
                << constraint->ToString();
        worklist.push_back(constraint);
      }
    }
  };
  add_new_constraints_to_worklist();

  while (!worklist.empty()) {
    const LayoutConstraint* layout_constraint = worklist.front();
    worklist.pop_front();
    VLOG(2) << "Propagating " << layout_constraint->ToString()
            << " to its neighbors with priority = "
            << layout_constraint->priority() << "\n";
    // Dispatch on the constraint's dynamic type: buffer, operand, or result.
    if (auto* buffer_constraint =
            dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateBufferConstraint(*buffer_constraint, constraints));
    } else if (auto* operand_constraint =
                   dynamic_cast<const OperandLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateOperandConstraint(*operand_constraint, constraints));
    } else if (auto* result_constraint =
                   dynamic_cast<const ResultLayoutConstraint*>(
                       layout_constraint)) {
      TF_RETURN_IF_ERROR(
          PropagateResultConstraint(*result_constraint, constraints));
    } else {
      LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
    }

    // Propagation may have generated new constraints; pick them up before the
    // next iteration.
    add_new_constraints_to_worklist();
  }
  return Status::OK();
}
1309
1310 namespace {
1311
1312 // Returns a vector containing all array-shaped uses (instruction and operand
1313 // number) of the given logical buffer or its aliases.
GetArrayUsesOfBuffer(const LogicalBuffer & buffer,const TuplePointsToAnalysis & points_to_analysis)1314 std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
1315 const LogicalBuffer& buffer,
1316 const TuplePointsToAnalysis& points_to_analysis) {
1317 CHECK(buffer.IsArray());
1318 std::vector<std::pair<const HloInstruction*, int64>> uses;
1319 for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
1320 if (!buffer_alias.instruction()->shape().IsArray()) {
1321 continue;
1322 }
1323 // This alias must be the top-level (index == {}) of the instruction's
1324 // result because the instruction produces an array.
1325 CHECK(buffer_alias.index().empty());
1326
1327 // Add all uses of the instruction's output.
1328 for (const HloInstruction* user : buffer_alias.instruction()->users()) {
1329 for (int64_t operand_no :
1330 user->OperandIndices(buffer_alias.instruction())) {
1331 uses.emplace_back(user, operand_no);
1332 }
1333 }
1334 }
1335 return uses;
1336 }
1337
1338 } // namespace
1339
PropagateUseConstraintToDefs(const ShapeLayout & shape_layout,const HloInstruction * instruction,LayoutConstraints * constraints)1340 Status LayoutAssignment::PropagateUseConstraintToDefs(
1341 const ShapeLayout& shape_layout, const HloInstruction* instruction,
1342 LayoutConstraints* constraints) {
1343 // Try to set all logical buffers which may be sources of the given operand to
1344 // match the given layout.
1345 const PointsToSet& points_to_set =
1346 constraints->points_to_analysis().GetPointsToSet(instruction);
1347 return points_to_set.ForEachElementWithStatus(
1348 [&shape_layout, constraints](
1349 const ShapeIndex& index,
1350 const PointsToSet::BufferList& buffers) -> Status {
1351 if (ShapeUtil::IsLeafIndex(shape_layout.shape(), index)) {
1352 for (const LogicalBuffer* buffer : buffers) {
1353 if (constraints->BufferLayout(*buffer) == nullptr &&
1354 buffer->shape().IsArray()) {
1355 TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
1356 ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
1357 *buffer, /*mandatory=*/true));
1358 }
1359 }
1360 }
1361 return Status::OK();
1362 });
1363 }
1364
1365 namespace {
1366 // A transpose or a reshape that only changes trivial dimensions have meaningful
1367 // layouts that are valuable to propagate in a depthfirst manner to avoid
1368 // unassigned layouts in the graph.
InstructionShouldPropagateDepthFirst(const HloInstruction & hlo,bool forward_propagation=true)1369 bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
1370 bool forward_propagation = true) {
1371 switch (hlo.opcode()) {
1372 case HloOpcode::kFusion:
1373 return hlo.IsCustomFusion();
1374 case HloOpcode::kGather:
1375 return true;
1376 case HloOpcode::kReshape:
1377 return hlo.operand(0)->shape().rank() == 1 ||
1378 (forward_propagation &&
1379 std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
1380 case HloOpcode::kScatter:
1381 case HloOpcode::kTranspose:
1382 return true;
1383 default:
1384 return false;
1385 }
1386 }
1387
1388 } // namespace
1389
Status LayoutAssignment::PropagateOperandConstraint(
    const OperandLayoutConstraint& operand_constraint,
    LayoutConstraints* constraints) {
  // Try to set the layout of the logical buffers in the given operand to match
  // the constrained layout. This avoids copies.
  TF_RETURN_IF_ERROR(
      PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
                                   operand_constraint.operand(), constraints));

  // For array-shaped operands and user instructions try to pick a minimum cost
  // layout. For example, if the operand of an elementwise instruction is
  // constrained to a certain layout we want the output of the instruction to
  // have the same layout.
  //
  // If the user is not array-shaped, we still want to propagate the layout
  // to siblings if the instruction can't change layout. This is to represent
  // the information that non-layout-changing instructions should have the same
  // layout for the operands with the same ranks.
  const HloInstruction* operand = operand_constraint.operand();
  const HloInstruction* user = operand_constraint.instruction();
  if (!operand->shape().IsArray()) {
    return Status::OK();
  }

  if (user->opcode() == HloOpcode::kAllReduce) {
    // For a single-operand all-reduce the output aliases the whole shape
    // ({} index); for the multi-operand (tuple-shaped) form, output element i
    // corresponds to operand i. Mirror the operand layout onto the matching
    // output buffer if that buffer is still unconstrained.
    const auto shape_index =
        user->operand_count() == 1
            ? ShapeIndex()
            : ShapeIndex({operand_constraint.operand_no()});
    TF_ASSIGN_OR_RETURN(const LogicalBuffer* buffer,
                        constraints->points_to_analysis().GetBufferDefinedAt(
                            user, shape_index));
    const BufferLayoutConstraint* constraint =
        constraints->GetBufferLayoutConstraint(*buffer);
    if (constraint == nullptr) {
      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          operand_constraint.shape_layout().layout(), *buffer,
          /*mandatory=*/false));
    }
  }
  if (InstructionCanChangeLayoutInstance(user) && !user->shape().IsArray()) {
    return Status::OK();
  }

  // Only try to choose a low cost layout if the instruction 'user' defines its
  // output (ie, doesn't forward a buffer from elsewhere).
  if (constraints->AnyOperandBufferForwarded(user,
                                             operand_constraint.operand_no())) {
    return Status::OK();
  }

  // Rank-0/1 operands have a trivial layout; nothing useful to propagate.
  int64_t operand_rank = operand->shape().rank();
  if (operand_rank <= 1) {
    return Status::OK();
  }

  // Propagate layouts between operands of the same instruction. This is a
  // constraint on non-layout-changing instructions.
  if (!InstructionCanChangeLayoutInstance(user)) {
    // Only propagate the layout of the largest concatenate operand: if any
    // sibling is larger along the concatenate dimension, bail out and let that
    // sibling's constraint win instead.
    if (user->opcode() == HloOpcode::kConcatenate) {
      for (int64_t operand_no = 0; operand_no < user->operand_count();
           ++operand_no) {
        const HloInstruction* sibling = user->operand(operand_no);
        if (sibling == operand) {
          continue;
        }
        if (sibling->shape().dimensions(user->concatenate_dimension()) >
            operand->shape().dimensions(user->concatenate_dimension())) {
          return Status::OK();
        }
      }
    }
    // Make sure all siblings have the same layout as the operand.
    for (int64_t operand_no = 0; operand_no < user->operand_count();
         ++operand_no) {
      if (user->operand(operand_no) == operand) {
        continue;
      }
      const HloInstruction* sibling = user->operand(operand_no);
      const int64_t sibling_rank = sibling->shape().rank();
      if (sibling_rank <= 1) {
        continue;
      }
      // Layouts can only be copied between operands of equal rank.
      if (operand_rank != sibling_rank) {
        continue;
      }
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(user, operand_no);
      if (constraint != nullptr) {
        // Due to the DFS of the propagation we can end up here when operand_no
        // has a layout set that hasn't been propagated yet (is still on the
        // stack of layouts to propagate).
        // We can continue here and leave the operands with different layouts,
        // as we will either:
        // - overwrite the current operand when the DFS gets back to propagating
        // operand(operand_no) to its siblings
        // - overwrite operand(operand_no)'s layout with a mandatory layout if
        // we continue to propagate our layout to the result, and then
        // backwards into all operands (if the result is an array of rank > 1)
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          operand_constraint.shape_layout().layout(), user, operand_no,
          /*mandatory=*/false));
    }
    // Also propagate the operand's layout into the user's output buffers:
    // every array subshape of matching rank that 'user' itself defines and
    // that is not yet constrained receives the operand layout.
    TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
        user->shape(),
        [&](const Shape& subshape, const ShapeIndex& shape_index) {
          if (subshape.IsTuple()) {
            return Status::OK();
          }
          if (subshape.rank() <= 1) {
            return Status::OK();
          }

          // Assign the right layout to input fusion of higher rank reduce
          // operations.
          if (subshape.rank() != operand->shape().rank()) {
            return Status::OK();
          }
          if (!constraints->points_to_analysis()
                   .InstructionDefinesBufferAtIndex(user, shape_index)) {
            return Status::OK();
          }
          // TODO(b/67641796): Are there cases except fusion that use this code
          // path?
          TF_ASSIGN_OR_RETURN(
              const LogicalBuffer* buffer,
              constraints->points_to_analysis().GetBufferDefinedAt(
                  user, shape_index));
          // Make sure the output has the same layout as the operand.
          const BufferLayoutConstraint* constraint =
              constraints->GetBufferLayoutConstraint(*buffer);
          // If we already have a constraint for the buffer it was assigned but
          // hasn't propagated yet. This can happen with diamond-shaped graphs
          // where one path is first evaluated in depth-first order (we're here)
          // and the other path is propagated later. We don't set the layout
          // here as it will always be overwritten later.
          if (constraint == nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                operand_constraint.shape_layout().layout(), *buffer,
                /*mandatory=*/false));
          }
          return Status::OK();
        }));
    return Status::OK();
  }
  // 'user' may change layout: ask the (heuristic) layout chooser for a cheap
  // output layout given the operand's layout, and apply it to each array
  // buffer the user defines, unless a mandatory constraint already exists.
  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      user->shape(), [&](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsTuple()) {
          return Status::OK();
        }
        if (subshape.rank() <= 1) {
          return Status::OK();
        }
        if (!constraints->points_to_analysis().InstructionDefinesBufferAtIndex(
                user, shape_index)) {
          return Status::OK();
        }
        TF_ASSIGN_OR_RETURN(
            const LogicalBuffer* buffer,
            constraints->points_to_analysis().GetBufferDefinedAt(user,
                                                                 shape_index));
        if (constraints->BufferLayout(*buffer) == nullptr ||
            !constraints->GetBufferLayoutConstraint(*buffer)->mandatory()) {
          // nullptr means the chooser has no preference for this user.
          std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
              operand_constraint.shape_layout().layout(), user,
              operand_constraint.operand_no());
          if (layout != nullptr) {
            TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
                *layout, *buffer,
                /*mandatory=*/user->opcode() == HloOpcode::kReduce,
                /*dfs=*/InstructionShouldPropagateDepthFirst(*user)));
          }
        }
        return Status::OK();
      }));
  return Status::OK();
}
1570
Status LayoutAssignment::PropagateBufferConstraintToOperands(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  VLOG(5) << "PropagateBufferConstraintToOperands: "
          << buffer_constraint.ToString();
  const LogicalBuffer& buffer = buffer_constraint.buffer();

  const HloInstruction* instruction = buffer.instruction();
  // Rank-0/1 shapes have a trivial layout; nothing to propagate.
  if (IsAtMostRank1(instruction->shape())) {
    return Status::OK();
  }

  if (instruction->opcode() == HloOpcode::kAllReduce) {
    // All-reduce output element buffer.index()[0] corresponds to the operand
    // with the same number (operand 0 for the single-operand form); that
    // operand must carry the same layout, hence mandatory.
    TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
        buffer_constraint.layout(), instruction,
        instruction->operand_count() == 1 ? 0 : buffer.index()[0],
        /*mandatory=*/true));
    return Status::OK();
  }
  for (int64_t operand_no = 0; operand_no < instruction->operand_count();
       ++operand_no) {
    const HloInstruction* operand = instruction->operand(operand_no);
    if (IsAtMostRank1(operand->shape())) {
      continue;
    }
    if (!InstructionCanChangeLayoutInstance(instruction)) {
      // Copy the layout to the operand. Only valid when the operand's rank
      // matches the constrained layout's rank (minor_to_major size).
      if (buffer.IsArray() && operand->shape().IsArray() &&
          operand->shape().rank() ==
              LayoutUtil::MinorToMajor(buffer_constraint.layout()).size()) {
        TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
            buffer_constraint.layout(), instruction, operand_no,
            /*mandatory=*/true));
      }
    } else {
      if (!buffer.IsTopLevel() ||
          !instruction->operand(operand_no)->shape().IsArray()) {
        continue;  // Don't touch buffers that are internal to a tuple.
      }
      VLOG(6) << "Propagating constraint to operand " << operand_no << " of "
              << instruction->ToShortString();
      // Assign a layout if there is no constraint already.
      const OperandLayoutConstraint* constraint =
          constraints->GetOperandLayoutConstraint(instruction, operand_no);
      if (constraint == nullptr || !constraint->mandatory()) {
        // Ask the layout chooser which operand layout is cheapest given the
        // output layout; nullptr means "no preference", in which case we
        // leave the operand unconstrained.
        std::unique_ptr<Layout> operand_layout =
            ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
                                                instruction, operand_no);
        if (operand_layout != nullptr) {
          TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
              *operand_layout, instruction, operand_no, /*mandatory=*/false,
              /*dfs=*/
              InstructionShouldPropagateDepthFirst(
                  *instruction, /*forward_propagation=*/false)));
        }
      } else {
        VLOG(6) << "Operand already has a constraint "
                << constraint->ToString();
      }
    }
  }
  return Status::OK();
}
1634
PropagateBufferConstraint(const BufferLayoutConstraint & buffer_constraint,LayoutConstraints * constraints)1635 Status LayoutAssignment::PropagateBufferConstraint(
1636 const BufferLayoutConstraint& buffer_constraint,
1637 LayoutConstraints* constraints) {
1638 // Only propagate array layouts.
1639 const LogicalBuffer& buffer = buffer_constraint.buffer();
1640 if (!buffer.IsArray()) {
1641 return Status::OK();
1642 }
1643 TF_RETURN_IF_ERROR(
1644 PropagateBufferConstraintToOperands(buffer_constraint, constraints));
1645 return PropagateBufferConstraintToUses(buffer_constraint, constraints);
1646 }
1647
Status LayoutAssignment::PropagateBufferConstraintToUses(
    const BufferLayoutConstraint& buffer_constraint,
    LayoutConstraints* constraints) {
  const LogicalBuffer& buffer = buffer_constraint.buffer();
  TF_RET_CHECK(buffer.IsArray());

  // Propagate the layout to all array uses of the logical buffer. This skips
  // uses of the buffer where the buffer is the element of a tuple.
  for (const auto& user_operand_no :
       GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
    const HloInstruction* user = user_operand_no.first;
    int64_t operand_no = user_operand_no.second;
    // Only add an operand constraint if the user does not forward the buffer
    // because this case is not handled in SetOperandLayout.
    if (constraints->OperandLayout(user, operand_no) == nullptr &&
        !constraints->AnyOperandBufferForwarded(user, operand_no)) {
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), user, operand_no, /*mandatory=*/false));
    }
  }

  // Propagate to backedges of kWhile. This only applies when the computation
  // that defines the buffer has exactly one call site and that call site is a
  // kWhile instruction (i.e. we are inside a while body or condition).
  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
  if (node.caller_callsites().size() != 1) {
    return Status::OK();
  }
  const HloInstruction* parent = node.caller_callsites()[0].instruction();
  if (parent->opcode() != HloOpcode::kWhile) {
    return Status::OK();
  }

  for (HloInstruction* user : buffer.instruction()->users()) {
    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
      continue;
    }
    // The buffer feeds the root tuple of the while body, so it flows around
    // the loop backedge: constrain the corresponding element of the body's
    // parameter to the same layout so the backedge layouts agree.
    if (user->parent()->root_instruction() == user) {
      VLOG(3) << "Propagating layout through backedge"
              << buffer_constraint.layout().ToString();
      int64_t index = user->operand_index(buffer.instruction());
      TF_ASSIGN_OR_RETURN(
          auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
                           user->parent()->parameter_instruction(0), {index}));

      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
          buffer_constraint.layout(), *buffer, /*mandatory=*/false));
    }
  }

  return Status::OK();
}
1698
PropagateResultConstraint(const ResultLayoutConstraint & layout_constraint,LayoutConstraints * constraints)1699 Status LayoutAssignment::PropagateResultConstraint(
1700 const ResultLayoutConstraint& layout_constraint,
1701 LayoutConstraints* constraints) {
1702 // Propagate the use constraint of the root instruction up to the logical
1703 // buffers which make up the result.
1704 return PropagateUseConstraintToDefs(
1705 layout_constraint.shape_layout(),
1706 constraints->computation()->root_instruction(), constraints);
1707 }
1708
1709 namespace {
1710
1711 // Infers the layout of the array at the given index in the given instruction's
1712 // output using points-to analysis. Precondition: The given instruction must
1713 // not produce this array value (that is, the array is forwarded from the
1714 // instruction's operands).
StatusOr<Layout> InferArrayLayout(
    const TuplePointsToAnalysis& points_to_analysis,
    HloInstruction* instruction, const ShapeIndex& index) {
  // This function should only be called for array shapes which don't yet have
  // layouts.
  const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
  TF_RET_CHECK(subshape.IsArray());
  TF_RET_CHECK(!subshape.has_layout());

  // The instruction should not define the buffer at this index.
  TF_RET_CHECK(
      !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index))
      << instruction->ToString();

  // The points-to set may be ambiguous: gather all LogicalBuffers this
  // location could alias.
  const auto& source_buffers =
      points_to_analysis.GetPointsToSet(instruction).element(index);
  TF_RET_CHECK(!source_buffers.empty());

  // Verify the layout is the same for every LogicalBuffer which this location
  // ('instruction' and 'index') points to.
  const Layout* first_buffer_layout = nullptr;
  for (const LogicalBuffer* source_buffer : source_buffers) {
    if (!source_buffer->shape().has_layout()) {
      // This should not happen because we've assigned layouts to all
      // instructions preceding this one.
      return InternalError("LogicalBuffer %s does not have a layout",
                           source_buffer->ToString());
    }

    if (first_buffer_layout == nullptr) {
      first_buffer_layout = &source_buffer->shape().layout();
    } else if (!Layout::Equal().MinorToMajorOnly()(
                   source_buffer->shape().layout(), *first_buffer_layout)) {
      // The points-to set is ambiguous for this index and the different source
      // buffers have different layouts. This case is possible in valid XLA
      // computations because we do not propagate BufferLayoutConstraints to all
      // LogicalBuffers which may alias the constrained LogicalBuffer at some
      // point in the computation.
      return FailedPrecondition(
          "Array at index {%s} in instruction %s aliases buffers %s "
          "and %s which have different layouts",
          absl::StrJoin(index, ","), instruction->name(),
          source_buffers[0]->ToString(), source_buffer->ToString());
    }
  }

  // All aliased buffers agree (modulo tiling etc.; only minor_to_major was
  // compared), so any of their layouts can be used.
  return *first_buffer_layout;
}
1763
1764 // For fusion instructions, set the layout of each fused parameter instruction
1765 // to match the layout of its corresponding fusion instruction operand. Also,
1766 // set the layout of the fused root to match the layout of the fusion
1767 // instruction itself.
Status SetFusionLayouts(HloInstruction* fusion) {
  TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
  for (auto* fused_instruction :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (fused_instruction->opcode() == HloOpcode::kParameter) {
      // Fused parameter i mirrors the fusion instruction's operand i.
      const HloInstruction* fusion_operand =
          fusion->operand(fused_instruction->parameter_number());
      DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
                                   fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion_operand->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction == fusion->fused_expression_root()) {
      // The layout of the root of the fused expression must match the fusion
      // instruction layout.
      DCHECK(
          ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fusion->shape(), fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) {
      // A GTE inherits its layout from its operand (which should ultimately be
      // a parameter).
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->operand(0)->shape().tuple_shapes(
              fused_instruction->tuple_index()),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kConstant) {
      // Give constants the layout of their literal.
      TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
          fused_instruction->literal().shape(),
          fused_instruction->mutable_shape()));
    } else if (fused_instruction->opcode() == HloOpcode::kInfeed) {
      // Nop; leave the infeed layout alone.
    } else if (!fusion->IsCustomFusion()) {
      // Other instructions don't have layouts inside of fusion nodes.
      // But do not clear layouts for other instructions in custom fusion nodes.
      LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
    }
  }

  return Status::OK();
}
1809
1810 } // namespace
1811
AssignLayouts(const LayoutConstraints & constraints,HloComputation * computation)1812 Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
1813 HloComputation* computation) {
1814 VLOG(2) << "Assigning layouts to computation: " << computation->name();
1815 XLA_VLOG_LINES(2, computation->ToString());
1816 XLA_VLOG_LINES(2, constraints.ToString());
1817
1818 for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
1819 LayoutUtil::ClearLayout(instruction->mutable_shape());
1820
1821 // Set the layouts of the array shapes this instruction defines as indicated
1822 // by the respective BufferLayoutConstraints. Any array shapes in the output
1823 // of the instruction which are not defined by the instruction (eg, array
1824 // elements in a Tuple instruction) will be assigned below via inference.
1825 for (const LogicalBuffer* buffer :
1826 constraints.points_to_analysis().GetBuffersDefinedByInstruction(
1827 instruction)) {
1828 if (!buffer->shape().IsArray()) {
1829 continue;
1830 }
1831
1832 TF_RET_CHECK(buffer->instruction() == instruction);
1833 const Layout* buffer_layout = constraints.BufferLayout(*buffer);
1834 TF_RET_CHECK(buffer_layout != nullptr);
1835
1836 if (instruction->opcode() == HloOpcode::kConstant) {
1837 // For constants, we also need to change the layout of the internal
1838 // literal.
1839 instruction->RelayoutConstant(*buffer_layout, buffer->index());
1840 } else {
1841 Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
1842 instruction->mutable_shape(), buffer->index());
1843 *buffer_subshape->mutable_layout() = *buffer_layout;
1844 }
1845 }
1846
1847 // Any remaining layouts in the output of the instruction must be
1848 // inferrable using points-to analysis.
1849 TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
1850 instruction->mutable_shape(),
1851 [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
1852 if (subshape->has_layout() || !subshape->IsArray()) {
1853 return Status::OK();
1854 }
1855 // Set Layout of subshape to match layout of LogicalBuffer which
1856 // produces it.
1857 TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
1858 InferArrayLayout(constraints.points_to_analysis(),
1859 instruction, index));
1860 return Status::OK();
1861 }));
1862 // Create a copy of an operand if the operand instruction's layout does not
1863 // match the use constraint (OperandLayoutConstraint).
1864 for (int64_t operand_no = 0; operand_no < instruction->operand_count();
1865 ++operand_no) {
1866 const ShapeLayout* operand_layout =
1867 constraints.OperandLayout(instruction, operand_no);
1868 if (operand_layout != nullptr) {
1869 TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
1870 instruction, operand_no));
1871 } else {
1872 VLOG(2) << "operand " << operand_no << " has no constraint";
1873 }
1874 }
1875
1876 // Process instructions that contain nested computations and may require
1877 // additional layouts to be assigned on the instructions nested inside.
1878 switch (instruction->opcode()) {
1879 case HloOpcode::kFusion:
1880 TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
1881 break;
1882 case HloOpcode::kCall:
1883 if (reverse_computation_order_) {
1884 ProgramShape program_shape;
1885 for (int64_t i = 0; i < instruction->operand_count(); ++i) {
1886 *program_shape.add_parameters() = instruction->operand(i)->shape();
1887 }
1888 *program_shape.mutable_result() = instruction->shape();
1889 computation_layouts_.emplace(
1890 instruction->to_apply(),
1891 ComputationLayout(program_shape,
1892 /*ignore layout=*/false));
1893 }
1894 break;
1895 case HloOpcode::kConditional:
1896 // If the branches don't yet have layouts, propagate existing layout
1897 // inside the branches.
1898 if (reverse_computation_order_) {
1899 for (int i = 0; i < instruction->branch_count(); ++i) {
1900 ProgramShape program_shape;
1901 auto* branch_operand = instruction->operand(i + 1);
1902 *program_shape.add_parameters() = branch_operand->shape();
1903 *program_shape.mutable_result() = instruction->shape();
1904 computation_layouts_.emplace(
1905 instruction->branch_computation(i),
1906 ComputationLayout(program_shape,
1907 /*ignore layout=*/false));
1908 }
1909 }
1910 break;
1911 case HloOpcode::kWhile:
1912 // If the loop body doesn't have layouts, propagate existing one inside.
1913 if (reverse_computation_order_) {
1914 VLOG(2) << "Populating while loop constraints inside loop body.";
1915 VLOG(2) << instruction->ToString();
1916 auto init_shape = instruction->operand(0)->shape();
1917 ProgramShape body_shape;
1918 *body_shape.add_parameters() = init_shape;
1919 *body_shape.mutable_result() = init_shape;
1920 computation_layouts_.emplace(
1921 instruction->while_body(),
1922 ComputationLayout(body_shape,
1923 /*ignore layout=*/false));
1924 VLOG(2) << "Populating while loop constraints inside loop condition.";
1925 VLOG(2) << instruction->ToString();
1926 ProgramShape condition_shape;
1927 *condition_shape.add_parameters() = init_shape;
1928 *condition_shape.mutable_result() = ShapeUtil::MakeScalarShape(PRED);
1929 }
1930 break;
1931 default:
1932 break;
1933 }
1934 // Execute extra verification step once the layout has been finalized.
1935 TF_RETURN_IF_ERROR(Verify(instruction));
1936
1937 // Shape must be valid.
1938 TF_RETURN_IF_ERROR(
1939 ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape()));
1940
1941 // Verify all layouts in the shape have been set.
1942 TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
1943 }
1944 return Status::OK();
1945 }
1946
CalculateComputationLayout(HloComputation * computation)1947 Status LayoutAssignment::CalculateComputationLayout(
1948 HloComputation* computation) {
1949 if (!reverse_computation_order_ ||
1950 computation_layouts_.find(computation) == computation_layouts_.end()) {
1951 ComputationLayout computation_layout(computation->ComputeProgramShape(),
1952 /*ignore_layouts=*/false);
1953 InsertOrDie(&computation_layouts_, computation, computation_layout);
1954 VLOG(2) << " Calculated ComputationLayout = "
1955 << computation_layout.ToString();
1956 }
1957 return Status::OK();
1958 }
1959
ClearComputationLayouts(HloComputation * computation)1960 Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
1961 // Clear existing layouts of the instructions. All layouts must be assigned
1962 // by the LayoutAssignment pass, except for those on parameters, the
1963 // computation result, and a couple special cases. The former two are
1964 // specified in computation_layout. Clearing the layouts here avoids hiding
1965 // potential bugs in the layout assignment pass that may accidentally use the
1966 // existing layout.
1967 for (HloInstruction* instruction : computation->instructions()) {
1968 if (instruction->opcode() == HloOpcode::kBitcast) {
1969 // bitcasts are inherently layout sensitive and so a bitcast instruction
1970 // present in the IR before layout assignment is a bug.
1971 return InternalError(
1972 "Unexpected bitcast operation seen during layout assignment: %s.",
1973 instruction->ToString());
1974 }
1975 // Some instructions carry mandatory layouts in their shape.
1976 if (instruction->opcode() != HloOpcode::kInfeed &&
1977 !IsLayoutConstrainedCustomCall(instruction) &&
1978 !IsLayoutConstrainedCollective(instruction)) {
1979 LayoutUtil::ClearLayout(instruction->mutable_shape());
1980 }
1981 }
1982 return Status::OK();
1983 }
1984
Status LayoutAssignment::RunOnComputation(
    ComputationLayout* computation_layout, HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
          << ")";

  // Must be run before clearing layouts.
  TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));

  TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
  if (computation_layout != nullptr) {
    auto it = computation_layouts_.find(computation);
    if (it == computation_layouts_.end()) {
      VLOG(2) << " New ComputationLayout = " << computation_layout->ToString();
      computation_layouts_.emplace(computation, *computation_layout);
    } else {
      // A layout for this computation was recorded earlier; the provided one
      // must be the recorded entry or the entry computation's layout.
      TF_RET_CHECK(computation_layout == &it->second ||
                   computation_layout == entry_computation_layout_);
      VLOG(2) << " Existing ComputationLayout = "
              << computation_layout->ToString();
    }
  } else {
    VLOG(2) << " No ComputationLayout specified (will be calculated)";
  }

  // Construct LayoutConstraints with all layout constraints of the computation.
  LayoutConstraints constraints(*points_to_analysis_, computation);

  // Add constraints required for correctness on all backends (eg, entry
  // parameter layout constraints).
  TF_RETURN_IF_ERROR(AddMandatoryConstraints(
      computation_layout, channel_constraints, computation, &constraints));

  // Add any backend-specific constraints.
  TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));

  // Propagates layouts from mandatory and backend constraints.
  TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

  // Prior to applying default layouts, we take note of all HLO instructions
  // which lack a layout constraint.
  for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
    unconstrained_layout_instructions_.insert(
        points_to_analysis_->GetBuffer(buffer_id).instruction());
  }

  // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
  // layout and propagate the change.
  while (!constraints.unconstrained_buffer_ids().empty()) {
    int unconstrained_count = constraints.unconstrained_buffer_ids().size();

    // Arbitrarily pick the first unconstrained buffer and give it the default
    // layout (or the literal layout, in case of constants). By construction
    // unconstrained_buffers() has a stable sort based on LogicalBuffer::Id.
    const LogicalBuffer& buffer = points_to_analysis_->GetBuffer(
        *constraints.unconstrained_buffer_ids().begin());
    const HloInstruction* instruction = buffer.instruction();
    Layout new_layout =
        instruction->opcode() == HloOpcode::kConstant
            ? ShapeUtil::GetSubshape(instruction->literal().shape(),
                                     buffer.index())
                  .layout()
            : GetUnconstrainedLayout(buffer);
    TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer,
                                                   /*mandatory=*/false));

    TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));

    // To verify progress has been made, check that the number of unconstrained
    // buffers has been reduced.
    CHECK_LT(constraints.unconstrained_buffer_ids().size(),
             unconstrained_count);
  }
  // All logical buffers should have constraints at this point. All that
  // remains is assign the constraints to the buffers and infer layouts for
  // aliased buffers.
  TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));

  // If the computation layout wasn't specified, now it is the time to compute
  // it according to the parameters and root instruction layouts.
  // This allows the first pass through this API to record the best flowing
  // layout to parameters and root instruction.
  if (computation_layout == nullptr) {
    TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
  }

  // Record the layouts assigned for any communication ops in
  // channel_constraints so that they are constrained for future modules.
  if (channel_constraints != nullptr) {
    TF_RETURN_IF_ERROR(
        ConstrainChannelLayouts(computation, channel_constraints));
  }

  // Copy the root instruction's result if its layout does not match the result
  // layout constraint.
  if (constraints.ResultLayout() != nullptr &&
      !constraints.ResultLayout()->MatchesLayoutInShape(
          computation->root_instruction()->shape(),
          /*minor_to_major_only=*/true)) {
    // If this computation was recorded as a conditional branch with a result
    // layout mismatch, overwrite the stored result layout with the recorded
    // mismatch layout before inserting the copy.
    if (conditional_mismatch_.count(computation) > 0) {
      *FindOrDie(computation_layouts_, computation).mutable_result_layout() =
          FindOrDie(conditional_mismatch_, computation).result_layout();
    }
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_root,
        CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
                                computation->root_instruction()));
    computation->set_root_instruction(new_root);
  }
  return Status::OK();
}
2096
Status LayoutAssignment::ConstrainChannelLayouts(
    HloComputation* computation,
    ChannelLayoutConstraints* channel_constraints) {
  // Host send/recv channels are tracked separately from device channels.
  auto get_channel_constraints = [&](const HloInstruction* instruction) {
    return IsHostSendRecv(instruction) ? &host_channel_constraints_
                                       : channel_constraints;
  };
  // Go through the kRecvDone instructions first. These must either impose
  // their layout, or find a matching one already existing (ConstrainChannel()
  // returns nullptr).
  for (HloInstruction* instruction : computation->instructions()) {
    if (instruction->opcode() == HloOpcode::kRecvDone) {
      const Layout* layout =
          get_channel_constraints(instruction)
              ->ConstrainChannel(
                  *instruction->channel_id(),
                  ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
      // A non-null return means the channel was already constrained to a
      // different layout, which is unrecoverable for a receive.
      TF_RET_CHECK(layout == nullptr)
          << instruction->ToString()
          << " cannot constrain layout as it was set to "
          << LayoutUtil::HumanString(*layout);
    }
  }
  // After that we go through the kSend. These are likely going to have a kCopy
  // as operand (otherwise we add it), so in case the constrained layout does
  // not match, we can change the kCopy layout (and the kSend one as well).
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() == HloOpcode::kSend) {
      HloInstruction* operand = instruction->mutable_operand(0);
      get_channel_constraints(instruction)
          ->ConstrainChannel(*instruction->channel_id(),
                             operand->shape().layout());
    } else if (instruction->IsCrossModuleAllReduce()) {
      // Cross-module all-reduce channels are constrained by the op's own
      // output layout.
      get_channel_constraints(instruction)
          ->ConstrainChannel(instruction->channel_id().value(),
                             instruction->shape().layout());
    }
  }
  return Status::OK();
}
2137
Status LayoutAssignment::PropagateMemorySpace(HloModule* module) {
  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module));
  for (const auto& buffer : alias_analysis->buffers()) {
    // First go through values to collect the memory spaces. All values of an
    // aliased buffer must agree on a single non-default memory space.
    int64_t buffer_memory_space = Layout::kDefaultMemorySpace;
    for (auto value : buffer.values()) {
      const Shape& defining_shape = value->defining_position().shape();
      int64_t memory_space = defining_shape.layout().memory_space();
      if (memory_space != Layout::kDefaultMemorySpace) {
        // Two distinct non-default memory spaces within one buffer is an
        // internal inconsistency.
        if (buffer_memory_space != Layout::kDefaultMemorySpace &&
            memory_space != buffer_memory_space) {
          return InternalError(
              "Buffer %d (%s) has conflicting memory spaces: %d and %d.",
              buffer.id(), value->ToShortString(), buffer_memory_space,
              memory_space);
        }
        buffer_memory_space = memory_space;
      }
    }

    // If we encounter a memory space other than the default, then propagate all
    // the positions with the buffer's memory space.
    if (buffer_memory_space != Layout::kDefaultMemorySpace) {
      for (auto value : buffer.values()) {
        for (auto& position : value->positions()) {
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          shape->mutable_layout()->set_memory_space(buffer_memory_space);
        }
      }
    }
  }
  return Status::OK();
}
2172
PropagateComputationLayouts(HloComputation * computation,ComputationLayout * computation_layout)2173 Status LayoutAssignment::PropagateComputationLayouts(
2174 HloComputation* computation, ComputationLayout* computation_layout) {
2175 ComputationLayout computed_computation_layout(
2176 computation->ComputeProgramShape(),
2177 /*ignore_layouts=*/false);
2178 for (int64_t i = 0; i < computed_computation_layout.parameter_count(); ++i) {
2179 ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
2180 bool needs_assign = false;
2181 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
2182 param_layout->shape(),
2183 [&](const Shape& subshape, const ShapeIndex& shape_index) {
2184 if (!ShapeUtil::IsLeafIndex(param_layout->shape(), shape_index)) {
2185 return Status::OK();
2186 }
2187 if (!subshape.has_layout()) {
2188 needs_assign = true;
2189 return Status::OK();
2190 }
2191 const auto& computed_subshape = ShapeUtil::GetSubshape(
2192 computed_computation_layout.parameter_shape(i), shape_index);
2193 if (subshape.layout() != computed_subshape.layout()) {
2194 return InternalError(
2195 "Assigned parameter shape %s does not match layout of "
2196 "computation shape: %s",
2197 computed_computation_layout.ToString(),
2198 computation_layout->ToString());
2199 }
2200 return Status::OK();
2201 }));
2202 if (needs_assign) {
2203 VLOG(4) << "Assigning layout to parameter " << i << " of computation "
2204 << computation->name() << ": "
2205 << computed_computation_layout.parameter_layout(i).ToString();
2206 *param_layout = computed_computation_layout.parameter_layout(i);
2207 }
2208 }
2209 ShapeLayout* result_layout = computation_layout->mutable_result_layout();
2210 if (!result_layout->LayoutIsSet()) {
2211 VLOG(4) << "Assigning result layout of computation " << computation->name()
2212 << ": " << computed_computation_layout.result_layout().ToString();
2213 *result_layout = computed_computation_layout.result_layout();
2214 } else {
2215 TF_RET_CHECK(
2216 Shape::Equal().IgnoreDynamicDimension().MinorToMajorOnlyInLayout()(
2217 computed_computation_layout.result_layout().shape(),
2218 result_layout->shape()));
2219 }
2220 return Status::OK();
2221 }
2222
Run(HloModule * module)2223 StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
2224 VLOG(2) << "Running layout assignment on module " << module->name();
2225 TF_RETURN_IF_ERROR(Init());
2226 call_graph_ = CallGraph::Build(module);
2227 auto computations = module->computations();
2228
2229 // Add copy to the operand of Send instructions, since we cannot call
2230 // SetOperandLayout on Send instructions as it aliases its input to the
2231 // output.
2232 //
2233 // TODO(b/68493863): Remove this once we can call SetOperandLayout() on the
2234 // operand buffers that aliases with the output.
2235 for (HloComputation* computation : module->computations()) {
2236 for (HloInstruction* instruction :
2237 computation->MakeInstructionPostOrder()) {
2238 if (instruction->opcode() == HloOpcode::kSend) {
2239 TF_RETURN_IF_ERROR(AddCopyForOperand(instruction, 0));
2240 }
2241 }
2242 }
2243
2244 // Clone Conditional computations with multiple callsites.
2245 for (HloComputation* computation : computations) {
2246 CallGraphNode& node = call_graph_->GetNode(computation);
2247 if (node.caller_callsites().size() == 1) {
2248 continue;
2249 }
2250 if (absl::c_none_of(node.caller_callsites(), [](CallSite caller) {
2251 return caller.instruction()->opcode() == HloOpcode::kConditional;
2252 })) {
2253 continue;
2254 }
2255 for (int64_t i = 0; i < node.caller_callsites().size() - 1; ++i) {
2256 HloInstruction* caller = node.caller_callsites()[i].instruction();
2257 if (caller->opcode() == HloOpcode::kConditional) {
2258 for (int64_t k = 0; k < caller->branch_count(); ++k) {
2259 if (computation == caller->branch_computation(k)) {
2260 caller->set_branch_computation(
2261 k, module->AddEmbeddedComputation(computation->Clone()));
2262 break;
2263 }
2264 }
2265 }
2266 }
2267 }
2268
2269 // Verify computation layout is sane.
2270 const HloComputation* entry = module->entry_computation();
2271 TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
2272 entry->num_parameters());
2273 for (int64_t i = 0; i < entry->num_parameters(); ++i) {
2274 TF_RET_CHECK(
2275 ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
2276 entry->parameter_instruction(i)->shape()));
2277 }
2278 TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
2279 entry->root_instruction()->shape()));
2280 // We do two passes. The first one we pass a nullptr ComputationLayout to
2281 // the RunOnComputation() calls (for non entry computations), and we register
2282 // the ComputationLayout which are naturally flowing in DFS fashion to the
2283 // parameters and root instruction.
2284 // Walking in DFS mode though, means that we can end up with incorrect layouts
2285 // when seen from an outer instruction, which has across-computation
2286 // constraints to impose.
2287 // For example, the kWhile instruction needs to enforce the same layouts for
2288 // the parameters and root of the body, as well as the condition parameters.
2289 // Similarly, the kConditional instruction needs to enforce the same layouts
2290 // for the root of the true and false computations.
2291 // So in the first pass, while allowing the layouts to flow to parameters and
2292 // root, we also fix up the eventually inconsistent ComputationLayout, which
2293 // will be then made mandatory by the second pass.
2294 for (int64_t i = 0; i < 2; ++i) {
2295 VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
2296 TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
2297 TF_ASSIGN_OR_RETURN(auto points_to_analysis,
2298 TuplePointsToAnalysis::Run(module));
2299 points_to_analysis_ = std::move(points_to_analysis);
2300 auto computations_to_work = module->MakeComputationPostOrder();
2301 // If the reverse_comptation_order_ flag is set, reverse the ordering of
2302 // traversing computations, to generate an alternative layout assignment.
2303 if (reverse_computation_order_ && !computations_to_work.empty()) {
2304 absl::c_reverse(computations_to_work);
2305
2306 VLOG(2) << "reversing traversal order for computation:";
2307 }
2308 for (auto* computation : computations_to_work) {
2309 if (computation->IsFusionComputation()) {
2310 continue;
2311 }
2312 if (computation == module->entry_computation()) {
2313 TF_RETURN_IF_ERROR(RunOnComputation(entry_computation_layout_,
2314 module->entry_computation(),
2315 channel_layout_constraints_));
2316 } else {
2317 ComputationLayout* computation_layout =
2318 (i == 1 && conditional_mismatch_.count(computation) == 0) ||
2319 (reverse_computation_order_ &&
2320 computation_layouts_.find(computation) !=
2321 computation_layouts_.end())
2322 ? &FindOrDie(computation_layouts_, computation)
2323 : nullptr;
2324 TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation,
2325 channel_layout_constraints_));
2326 }
2327 }
2328 VLOG(2) << "After running #" << i << ":" << module->ToString();
2329 }
2330 TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
2331 entry_computation_layout_));
2332
2333 TF_RETURN_IF_ERROR(PropagateMemorySpace(module));
2334
2335 TF_RETURN_IF_ERROR(CheckLayouts(module));
2336
2337 // All layouts are reset then reassigned by this pass.
2338 return true;
2339 }
2340
2341 /* static */
// Predicate: can `instruction` produce a result whose layout differs from the
// layouts of its operands? Opcodes returning false propagate layout through
// unchanged (operands and result must share a layout); opcodes returning true
// may have operand layouts constrained independently of the result layout.
// NOTE: the switch deliberately has no `default` so that adding a new
// HloOpcode triggers a compiler exhaustiveness warning here.
bool LayoutAssignment::InstructionCanChangeLayout(
    const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    // Layout-preserving ops (mostly elementwise and shape-preserving ops).
    case HloOpcode::kAbs:
    case HloOpcode::kAdd:
    case HloOpcode::kAddDependency:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCos:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllGatherStart:
    case HloOpcode::kAllGatherDone:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kLogistic:
    case HloOpcode::kMap:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kPad:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelect:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kWhile:
    case HloOpcode::kSetDimensionSize:
    // AllReduce is variadic so it needs to be careful to assign the same layout
    // to the corresponding input argument and Tuple index.
    case HloOpcode::kAllReduce:
    case HloOpcode::kReduceScatter:
    case HloOpcode::kAllReduceStart:
    case HloOpcode::kAllReduceDone:
      return false;
    // Ops whose result layout is independent of their operand layouts.
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBitcast:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCall:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kConstant:
    case HloOpcode::kConvolution:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kPartitionId:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kAfterAll:
    case HloOpcode::kTrace:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kGetDimensionSize:
      return true;
  }
}
2464
// Virtual per-instance hook for InstructionCanChangeLayout. The base
// implementation simply delegates to the static opcode-based predicate;
// backend subclasses can override to refine the decision per instruction.
bool LayoutAssignment::InstructionCanChangeLayoutInstance(
    const HloInstruction* instruction) {
  return InstructionCanChangeLayout(instruction);
}
2469
2470 /* static */
IsAtMostRank1(const Shape & shape)2471 bool LayoutAssignment::IsAtMostRank1(const Shape& shape) {
2472 if (shape.IsArray()) {
2473 return shape.rank() <= 1;
2474 }
2475 return absl::c_all_of(shape.tuple_shapes(), [](const Shape& subshape) {
2476 return IsAtMostRank1(subshape);
2477 });
2478 }
2479
Init()2480 Status LayoutAssignment::Init() {
2481 computation_layouts_.clear();
2482 conditional_mismatch_.clear();
2483 *entry_computation_layout_ = saved_entry_computation_layout_;
2484 return Status::OK();
2485 }
2486
ClearPreviousPassSideEffects(HloModule * module)2487 Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
2488 VLOG(5) << "Clearing previous side effects";
2489 // Clear all the copies which have been added, and all the related
2490 // instructions (like GTE and tuples).
2491 int64_t removed_copies = 0;
2492 for (HloComputation* computation : module->computations()) {
2493 for (HloInstruction* instruction :
2494 computation->MakeInstructionPostOrder()) {
2495 if (instruction->opcode() == HloOpcode::kCopy &&
2496 added_copies_.contains(instruction)) {
2497 VLOG(5) << "Removing added copy: " << instruction->ToString();
2498 TF_RETURN_IF_ERROR(
2499 instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
2500 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
2501 ++removed_copies;
2502 }
2503 }
2504 }
2505 added_copies_.clear();
2506 unconstrained_layout_instructions_.clear();
2507 if (removed_copies > 0) {
2508 TupleSimplifier tuple_simplifier;
2509 HloDCE dce;
2510 TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
2511 TF_RETURN_IF_ERROR(dce.Run(module).status());
2512 call_graph_ = CallGraph::Build(module);
2513 }
2514 return Status::OK();
2515 }
2516
AddCopyForOperand(HloInstruction * instruction,int64_t operand_number)2517 Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
2518 int64_t operand_number) {
2519 HloInstruction* operand = instruction->mutable_operand(operand_number);
2520 if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
2521 HloInstruction* copy =
2522 instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
2523 operand->shape(), HloOpcode::kCopy, operand));
2524 SetupCopiedInstruction(*operand, copy, {});
2525 LayoutUtil::ClearLayout(copy->mutable_shape());
2526 TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
2527 }
2528 return Status::OK();
2529 }
2530
2531 } // namespace xla
2532