/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace m = xla::match;
using ::testing::ElementsAre;

class LayoutAssignmentTest : public HloTestBase {
 protected:
  void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
                     ChannelLayoutConstraints* channel_constraints = nullptr) {
    LayoutAssignment layout_assignment(
        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
        /*channel_constraints=*/channel_constraints);
    EXPECT_IS_OK(layout_assignment.Run(m).status());
  }

  std::vector<int64> LayoutOf(HloModule* m, absl::string_view name) {
    auto minor_to_major =
        FindInstruction(m, name)->shape().layout().minor_to_major();
    return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
  }

  void ExpectLayoutIs(const Shape& shape,
                      absl::Span<const int64> minor_to_major) {
    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
        << "Expected layout " << expected << ", actual " << shape.layout();
  }

  void ExpectTupleLayoutIs(
      const Shape& shape,
      std::initializer_list<absl::Span<const int64>> minor_to_majors) {
    int i = 0;
    for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
      const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
      const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
      EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
          << "Expected tuple element " << i << " layout " << expected
          << ", actual " << actual;
      ++i;
    }
  }
};

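// A note on the notation used throughout these tests: minor_to_major lists
// dimensions from fastest-varying (minor) to slowest-varying (major). For a
// 2D shape f32[42,12] with layout {1,0}, dimension 1 varies fastest, so
// element (i, j) sits at linear index i * 12 + j; with layout {0,1},
// dimension 0 varies fastest and element (i, j) sits at linear index
// j * 42 + i.
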
TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout for two different layouts.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ashape, "param0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, ashape, "param1"));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = shape_layout;
    *computation_layout.mutable_parameter_layout(1) = shape_layout;
    *computation_layout.mutable_result_layout() = shape_layout;
    AssignLayouts(m.get(), &computation_layout);
    EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
  }
}

TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout, which has mixed layouts.
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build());

  Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
  Shape col_major_shape(ashape);
  *col_major_shape.mutable_layout() = col_major_layout;
  const ShapeLayout col_major(col_major_shape);

  Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
  Shape row_major_shape(ashape);
  *row_major_shape.mutable_layout() = row_major_layout;
  const ShapeLayout row_major(row_major_shape);

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = col_major;
  *computation_layout.mutable_parameter_layout(1) = row_major;
  *computation_layout.mutable_result_layout() = col_major;

  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      col_major_layout, computation->root_instruction()->shape().layout()));
}

TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // Verify that the layouts of the fused parameters in a fusion instruction
  // match those of the fusion operands. Other fused instructions should have
  // no layout.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
    auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
    Shape ashape = constant_literal1.shape();

    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal1)));
    auto constant2 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal2)));
    auto add = builder.AddInstruction(HloInstruction::CreateBinary(
        ashape, HloOpcode::kAdd, constant1, constant2));
    auto negate1 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
    auto negate2 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));

    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    auto fusion = computation->CreateFusionInstruction(
        {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = shape_layout;

    AssignLayouts(m.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_expression_root()->shape().layout()));

    // Inner fused nodes should not have a layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}

TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Verify the layouts of a tuple are assigned properly (the element layouts
  // match their source).
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  // To avoid having to construct a tuple layout in the ComputationLayout
  // below, make the result of the computation an array.
  auto get_element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kNegate, get_element0));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      negate->shape(), computation_layout.result_layout().shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
}

TEST_F(LayoutAssignmentTest, TupleSelect) {
  // Verify that the layout of a select with tuple operands is assigned
  // properly.
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));

  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape =
      ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}

TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct the following computation, which has conflicting layouts for
  // two elements of a tuple that share the same source logical buffer:
  //
  //   %constant = Constant(...)
  //   %inner_tuple = Tuple(%constant)
  //   %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout is col-major for the first element and row-major for
  // the second. This creates a conflict: the element of the inner_tuple would
  // need to be both col- and row-major. The conflict is resolved by
  // deep-copying the tuple and assigning the layouts of the copied arrays as
  // needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  LayoutAssignment layout_assignment(&computation_layout);
  AssignLayouts(m.get(), &computation_layout);

  // Layout assignment should have deep-copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  //  %constant = Constant(...) layout={1,0}
  //  %tuple.0 = Tuple(%constant) layout=({1,0})
  //  %copy = Copy(%constant) layout={0,1}  # layout transposed
  //  %tuple.1 = Tuple(%copy) layout=({0,1})
  //  %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
  HloInstruction* root = m->entry_computation()->root_instruction();
  // Verify the layout of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
                                  m::Tuple(m::Copy(m::Op().Is(constant))))));
}

TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // param -> log -> reshape -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));

  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  auto log_minor_to_major =
      AsInt64Slice(log->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
            PositionInContainer(log_minor_to_major, 2));

  auto reshape_minor_to_major =
      AsInt64Slice(reshape->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
            PositionInContainer(reshape_minor_to_major, 2));
}

// Test whether LayoutAssignment assigns layouts to elementwise operations to
// keep linear indices valid across them, and to transpositions to make them
// bitcasts.
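//
// For example, with the parameter pinned to layout {1,0} below, log inherits
// {1,0} and the transpose output gets {0,1}: element (i, j) of log sits at
// linear index i * 12 + j, and the transposed element (j, i) sits at the same
// linear index under {0,1}, so the transpose moves no data and is a bitcast.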
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // param -> log -> transpose -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(bshape, log, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
  auto m = CreateNewVerifiedModule();
  auto computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
                                transpose->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
}

// Test whether LayoutAssignment assigns layouts to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // param -> broadcast -> transpose
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
  Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));

  Shape input_shape_with_layout(ashape);
  Shape output_shape_with_layout(cshape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *output_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_shape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

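  // The result constraint pins the transpose to layout {2,1,0}; with the
  // reversing permutation {2,1,0}, the broadcast layout that keeps the
  // transpose a bitcast is the reverse, {0,1,2}, which the check below
  // expects.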
  EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}

TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] ------> transpose[4x3] --------------> tuple
  //                    \                                           /
  //                     \-> tanh[3x4] -> broadcast2[2x3x4] ------>/
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore, `broadcast`
  // (the operand of the transpose) is expected to have layout {0,1} so that
  // the transpose is a bitcast. Furthermore, `tanh` is expected to have the
  // same layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
  Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
  Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
  Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_4, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_34, param, {1}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
  auto broadcast2 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose, broadcast2}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tuple));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  Shape param_shape_with_layout(f32_4);
  Shape transpose_shape_with_layout(f32_43);
  Shape broadcast2_shape_with_layout(f32_234);
  *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  *transpose_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});
  *broadcast2_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_shape_with_layout, broadcast2_shape_with_layout}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const LogicalBuffer& buffer = buffer_constraint.buffer();
    const HloInstruction* instruction = buffer.instruction();

    // Force the operands' layout to the output layout.
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const HloInstruction* operand = instruction->operand(operand_no);
      if (instruction->shape().rank() != operand->shape().rank()) {
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), instruction, operand_no,
          /*mandatory=*/true));
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};

TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // param0 -> concatenate -> reshape
  // param1 -^
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
  Shape cshape = ShapeUtil::MakeShape(F32, {100});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param"));
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
  auto reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(cshape, concatenate));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(reshape));

  Shape param0_shape_with_layout(ashape);
  Shape param1_shape_with_layout(ashape);
  *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param0_shape_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(param1_shape_with_layout);
  OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());

  EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
  EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(concatenate->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(concatenate->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
}

// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape_with_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
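  // With the operand layout {2,0,3,1} and permutation {2,3,0,1}, preserving
  // the operand's physical element order works out to the output layout
  // {0,2,1,3}; under that assignment the transpose moves no data.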
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(input_shape, constant, {}));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// TransposeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo = builder.AddInstruction(
      HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1}));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(),
                                             hlo->shape(), hlo->dimensions()),
               "LayoutUtil::HasLayout");
}

// ReshapeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo =
      builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(
      ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()),
      "LayoutUtil::HasLayout");
}

// Check that the computation below doesn't crash the compiler.
//
// Within a fusion computation, only the parameters and result get assigned a
// layout. When we run the algebraic simplifier on this computation post layout
// assignment, it should not call TransposeIsBitcast on the `transpose` node
// inside the fusion computation, since TransposeIsBitcast checks that both
// input_shape and output_shape have layouts.
TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
  const char* module_str = R"(
HloModule test_module

fused_computation {
  param_1 = f32[2,2,2]{2,1,0} parameter(1)
  transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
  reduce_1 = f32[] parameter(0)
  broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
  ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
}

ENTRY entry_computation {
  fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
  reduce.1 = f32[] parameter(0)
  fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
  ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  std::unique_ptr<HloModule> compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();

  EXPECT_EQ(Status::OK(), backend()
                              .compiler()
                              ->RunBackend(std::move(compiled_module),
                                           backend().default_stream_executor(),
                                           /*device_allocator=*/nullptr)
                              .status());
}

// A GTE inside a fusion node inherits the layout of its operand (which
// should, if we keep following operands, eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
HloModule test_module

fused_computation {
  fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
  gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
  gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
  gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
  gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
  add = f32[2,2,2] add(gte1a, gte1b)
  ROOT fresult = f32[2,2,2] add(gte0, add)
}

ENTRY entry_computation {
  param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
  ROOT fusion =
    f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
       ShapeUtil::MakeTupleShape({
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
       })});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2));
  EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0));
  EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(0)
                  .layout()
                  .minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(1)
                  .layout()
                  .minor_to_major(),
              ElementsAre(2, 0, 1));
}

TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto m = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      m->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // Use an infeed here, since layout assignment does not alter its layout.
    auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
    auto infeed = false_builder.AddInstruction(
        HloInstruction::CreateInfeed(xshape, token, ""));
    auto infeed_data = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
  }
  HloComputation* false_computation =
      m->AddEmbeddedComputation(false_builder.Build());
  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
  EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}

TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kBitcast, constant0));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  LayoutAssignment layout_assignment(&computation_layout);
  Status error_status = layout_assignment.Run(m.get()).status();
  EXPECT_FALSE(error_status.ok());
  EXPECT_THAT(
      error_status.error_message(),
      ::testing::HasSubstr(
          "Unexpected bitcast operation seen during layout assignment"));
}

TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
  // Pin non-matching layouts to the parameter and the root.
  const char* module_str = R"(
HloModule test_module

ENTRY entry_computation {
  param = (f32[2,2]) parameter(0)
  gte = f32[2,2] get-tuple-element(param), index=0
  token0 = token[] after-all()
  recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
  recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
    sharding={maximal device=1}
  ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
  send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
    sharding={maximal device=0}
  send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

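  // The send and recv share channel_id=1, so the layout pinned on the recv
  // side ({1,0}, via the result constraint) must apply to the send as well;
  // since gte is pinned to {0,1} by the parameter layout, a copy is expected
  // on the send path.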
  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0));
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}),
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}

TEST_F(LayoutAssignmentTest, AllReduceLayoutMismatch) {
  // Pin non-matching layouts to the parameter and the root.
  const char* module_str = R"(
HloModule test_module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry_computation {
  param = (f32[2,2]) parameter(0)
  gte = f32[2,2] get-tuple-element(param), index=0
  ar.0 = f32[2,2] all-reduce(gte),
    all_reduce_id=1, replica_groups={{0}}, to_apply=add,
    sharding={maximal device=0}
  const = f32[2,2] constant({{0,1},{2,3}})
  ROOT ar.1 = f32[2,2] all-reduce(const),
    all_reduce_id=1, replica_groups={{0}}, to_apply=add,
    sharding={maximal device=1}
})";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

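  // Both all-reduces participate in the same group (all_reduce_id=1), so they
  // are constrained to a common layout. The operand-derived layout {0,1} is
  // expected to win for both, with a copy then satisfying the {1,0} result
  // constraint at the root.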
  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
HloModule CopySliceOperandToAvoidImplicitLayoutChange

ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
  par0 = f32[3,4]{1,0} parameter(0)
  par1 = f32[4,5]{0,1} parameter(1)
  slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
  ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
}

TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
HloModule CopyDSliceOperandToAvoidImplicitLayoutChange

ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
  par0 = f32[3,4]{1,0} parameter(0)
  par1 = f32[4,5]{0,1} parameter(1)
  par2 = s32[] parameter(2)
  par3 = s32[] parameter(3)
  dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
  ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(root,
              GmockMatch(m::Add(
                  m::Parameter(),
                  m::DynamicSlice(
                      m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                      m::Parameter(2), m::Parameter(3)))));
}

TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
HloModule CopyConcatOperandToAvoidImplicitLayoutChange

ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
  par0 = f32[3,8]{1,0} parameter(0)
  par1 = f32[3,5]{0,1} parameter(1)
  par2 = f32[3,3]{1,0} parameter(2)
  concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
    dimensions={1}
  ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                         m::Parameter(2)))));
}

TEST_F(LayoutAssignmentTest,
       ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
  const char* module_str = R"(
HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied

ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
  par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
  par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
  ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
    window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
    feature_group_count=1
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}

TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
  const char* module_str = R"(
HloModule PropagatingLayoutFromResultToOperand

ENTRY PropagatingLayoutFromResultToOperand {
  par0 = f32[4,5]{1,0} parameter(0)
  ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
  EXPECT_THAT(root,
              GmockMatch(m::Slice(
                  m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
}

TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
  // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
  // The mismatch forces a copy of the tuple. The tuple contains a token, so
  // layout assignment will fail if it tries to copy the whole tuple.
  const char* module_str = R"(
HloModule TupleCopyOnLayoutMismatch

condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
  tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
  counter.1 = s32[] get-tuple-element(tup.1), index=0
  five = s32[] constant(5)
  ROOT lt = pred[] compare(counter.1, five), direction=LT
}

body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
  tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
  counter.2 = s32[] get-tuple-element(tup.2), index=0
  tok.2 = token[] get-tuple-element(tup.2), index=1

  ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
  next_tok = token[] get-tuple-element(ifeed.2), index=1
  next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0

  one = s32[] constant(1)
  next_counter = s32[] add(counter.2, one)
  ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
}

ENTRY main () -> f32[512,1024]{0,1} {
  start_tok = token[] after-all()

  ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
  itok = token[] get-tuple-element(ifeed.3), index=1
  ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0

  zero = s32[] constant(0)
  itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)

  loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
  ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  // Sanity check to verify that there's a layout mismatch.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));

  AssignLayouts(m.get(), &computation_layout);

  // Make sure that layout assignment did not magically eliminate the mismatch;
  // if it had, the test would prove nothing.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));
}

TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
  const char* module_str = R"(
HloModule CustomCallNotLayoutConstrained

ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
  %p = f32[42,2,3] parameter(0)
  ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
}
)";
  // Try a couple of different layouts. In each case the custom call's operand
  // and result layouts should match those of the computation.
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
  }
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
  }
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrained

ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  // The custom call should be partially encapsulated in kCopy instructions
  // because of the layout mismatches.
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedZeroOperands

ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall())));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleOperand

ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  HloInstruction* root = m->entry_computation()->root_instruction();
  ExpectLayoutIs(root->shape(), {2, 1, 0, 3});

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleResult

ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
  %p0 = f32[4,4] parameter(0)
  ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
  AssignLayouts(m.get(), &computation_layout);

  ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}});

  const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call");
  ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
}

Status AssignLayoutsToComputation(
    HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
  if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
    m->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->SetToDefaultLayout();
  }
  LayoutAssignment layout_assignment(
      m->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, channel_constraints);
  return layout_assignment.Run(m).status();
}

TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
  // Check that we handle a diamond-shaped graph correctly.
  //      transpose
  //       /     \
  //     add     |
  //       \     /
  //        tuple

  auto b = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {12, 8});
  Shape bshape = ShapeUtil::MakeShape(F32, {8, 12});
  auto param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input"));
  auto param1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input"));
  auto transpose =
      b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0}));
  auto add = b.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1));
  b.AddInstruction(HloInstruction::CreateTuple({add, transpose}));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(b.Build());
  Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
  Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
  *m->mutable_entry_computation_layout()->mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor}));
  const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0});
  ForceParameterLayout(m.get(), 0, r2_dim0major);
  ForceParameterLayout(m.get(), 1, r2_dim0major);
  TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));

  EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));

  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

}  // namespace
}  // namespace xla