/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace m = xla::match;
using ::testing::ElementsAre;

class LayoutAssignmentTest : public HloTestBase {
 protected:
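  // Runs the LayoutAssignment pass on `m` with the given entry computation
  // layout and optional channel layout constraints, expecting it to succeed.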
  void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
                     ChannelLayoutConstraints* channel_constraints = nullptr) {
    LayoutAssignment layout_assignment(
        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
        /*channel_constraints=*/channel_constraints);
    EXPECT_IS_OK(layout_assignment.Run(m).status());
  }

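  // Returns the minor-to-major order of the layout assigned to the
  // instruction named `name` in `m`.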
  std::vector<int64> LayoutOf(HloModule* m, absl::string_view name) {
    auto minor_to_major =
        FindInstruction(m, name)->shape().layout().minor_to_major();
    return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
  }

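  // Checks that `shape` has exactly the layout given by `minor_to_major`.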
  void ExpectLayoutIs(const Shape& shape,
                      absl::Span<const int64> minor_to_major) {
    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
        << "Expected layout " << expected << ", actual " << shape.layout();
  }

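  // Checks that the i-th element of tuple `shape` has the i-th layout in
  // `minor_to_majors`.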
  void ExpectTupleLayoutIs(
      const Shape& shape,
      std::initializer_list<absl::Span<const int64>> minor_to_majors) {
    int i = 0;
    for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
      const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
      const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
      EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
          << "Expected tuple element " << i << " layout " << expected
          << ", actual " << actual;
      ++i;
    }
  }
};

TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout for two different layouts.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ashape, "param0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, ashape, "param1"));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = shape_layout;
    *computation_layout.mutable_parameter_layout(1) = shape_layout;
    *computation_layout.mutable_result_layout() = shape_layout;
    AssignLayouts(m.get(), &computation_layout);
    EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
  }
}

TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match a ComputationLayout with mixed layouts.
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build());

  Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
  Shape col_major_shape(ashape);
  *col_major_shape.mutable_layout() = col_major_layout;
  const ShapeLayout col_major(col_major_shape);

  Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
  Shape row_major_shape(ashape);
  *row_major_shape.mutable_layout() = row_major_layout;
  const ShapeLayout row_major(row_major_shape);

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = col_major;
  *computation_layout.mutable_parameter_layout(1) = row_major;
  *computation_layout.mutable_result_layout() = col_major;

  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      col_major_layout, computation->root_instruction()->shape().layout()));
}

TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // Verify that the layouts of the fused parameters in a fusion instruction
  // match those of the fusion operands. Other fused instructions should have
  // no layout.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
    auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
    Shape ashape = constant_literal1.shape();

    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal1)));
    auto constant2 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal2)));
    auto add = builder.AddInstruction(HloInstruction::CreateBinary(
        ashape, HloOpcode::kAdd, constant1, constant2));
    auto negate1 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
    auto negate2 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));

    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    auto fusion = computation->CreateFusionInstruction(
        {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = shape_layout;

    AssignLayouts(m.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_expression_root()->shape().layout()));

    // Inner fused nodes should not have a layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}

TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Verify that the layouts of a tuple are assigned properly (the element
  // layouts match their sources).
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  // To avoid having to construct a tuple layout in the ComputationLayout
  // below, make the result of the computation an array.
  auto get_element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kNegate, get_element0));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      negate->shape(), computation_layout.result_layout().shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
}

TEST_F(LayoutAssignmentTest, TupleSelect) {
  // Verify that the layout of a select with tuple operands is assigned
  // properly.
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));

  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape =
      ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}

TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct the following computation, which has conflicting layouts for two
  // elements of a tuple that share the same source logical buffer:
  //
  //   %constant = Constant(...)
  //   %inner_tuple = Tuple(%constant)
  //   %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout is col-major for the first element and row-major for the
  // second. This creates a conflict where the element of the inner_tuple needs
  // to be both col-major and row-major. The conflict is resolved by
  // deep-copying the tuple and assigning the layouts of the copied arrays as
  // needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  // Layout assignment should have deep-copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  //   %constant = Constant(...) layout={1,0}
  //   %tuple.0 = Tuple(%constant) layout=({1,0})
  //   %copy = Copy(%constant) layout={0,1}  # layout transposed
  //   %tuple.1 = Tuple(%copy) layout=({0,1})
  //   %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
  HloInstruction* root = m->entry_computation()->root_instruction();
  // Verify the layouts of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
                                  m::Tuple(m::Copy(m::Op().Is(constant))))));
}

TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // param -> log -> reshape -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));

  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  auto log_minor_to_major =
      AsInt64Slice(log->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
            PositionInContainer(log_minor_to_major, 2));

  auto reshape_minor_to_major =
      AsInt64Slice(reshape->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
            PositionInContainer(reshape_minor_to_major, 2));
}

// Test whether LayoutAssignment assigns layouts to elementwise operations to
// keep linear indices valid across them, and to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // param -> log -> transpose -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(bshape, log, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
  auto m = CreateNewVerifiedModule();
  auto computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
                                transpose->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
}

// Test whether LayoutAssignment assigns layouts to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // param -> broadcast -> transpose
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
  Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));

  Shape input_shape_with_layout(ashape);
  Shape output_shape_with_layout(cshape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *output_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_shape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}

TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] ------> transpose[4x3] --------------> tuple
  //                            \                                    /
  //                             \-> tanh[3x4] -> broadcast2[2x3x4] /
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore, `broadcast`
  // (the operand of transpose) is expected to have layout {0,1} so that the
  // transpose is a bitcast. Furthermore, `tanh` is expected to have the same
  // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
  Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
  Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
  Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_4, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_34, param, {1}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
  auto broadcast2 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose, broadcast2}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tuple));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  Shape param_shape_with_layout(f32_4);
  Shape transpose_shape_with_layout(f32_43);
  Shape broadcast2_shape_with_layout(f32_234);
  *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  *transpose_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});
  *broadcast2_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_shape_with_layout, broadcast2_shape_with_layout}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

// A LayoutAssignment subclass that forces every operand of an instruction to
// take the layout of the instruction's output buffer (when the ranks match).
class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const LogicalBuffer& buffer = buffer_constraint.buffer();
    const HloInstruction* instruction = buffer.instruction();

    // Force the operands' layouts to the output layout.
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const HloInstruction* operand = instruction->operand(operand_no);
      if (instruction->shape().rank() != operand->shape().rank()) {
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), instruction, operand_no,
          /*mandatory=*/true));
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};

TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // param0 -> concatenate -> reshape
  // param1 -^
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
  Shape cshape = ShapeUtil::MakeShape(F32, {100});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param"));
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
  auto reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(cshape, concatenate));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(reshape));

  Shape param0_shape_with_layout(ashape);
  Shape param1_shape_with_layout(ashape);
  *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param0_shape_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(param1_shape_with_layout);
  OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());

  EXPECT_EQ(concatenate->operand(0)->shape().layout().minor_to_major(),
            concatenate->operand(1)->shape().layout().minor_to_major());
  EXPECT_EQ(concatenate->shape().layout().minor_to_major(),
            concatenate->operand(1)->shape().layout().minor_to_major());
}

// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape_with_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(input_shape, constant, {}));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// TransposeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo = builder.AddInstruction(
      HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1}));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(),
                                             hlo->shape(), hlo->dimensions()),
               "LayoutUtil::HasLayout");
}

// ReshapeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo =
      builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(
      ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()),
      "LayoutUtil::HasLayout");
}

// Check that the computation below doesn't crash the compiler.
//
// Within a fusion computation, only the parameters and result get assigned a
// layout. When we run the algebraic simplifier on this computation post layout
// assignment, it should not call TransposeIsBitcast on the `transpose` node
// inside the fusion computation, as TransposeIsBitcast checks that both
// input_shape and output_shape have layouts.
TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      param_1 = f32[2,2,2]{2,1,0} parameter(1)
      transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
      reduce_1 = f32[] parameter(0)
      broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
      ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
    }

    ENTRY entry_computation {
      fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
      reduce.1 = f32[] parameter(0)
      fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
      ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  std::unique_ptr<HloModule> compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();

  EXPECT_EQ(Status::OK(), backend()
                              .compiler()
                              ->RunBackend(std::move(compiled_module),
                                           backend().default_stream_executor(),
                                           /*device_allocator=*/nullptr)
                              .status());
}

// A GTE inside of a fusion node inherits the layout of its operand (which
// should, if we keep following operands, eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
      gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
      gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
      gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
      add = f32[2,2,2] add(gte1a, gte1b)
      ROOT fresult = f32[2,2,2] add(gte0, add)
    }

    ENTRY entry_computation {
      param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      ROOT fusion =
        f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
       ShapeUtil::MakeTupleShape({
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
       })});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2));
  EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0));
  EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(0)
                  .layout()
                  .minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(1)
                  .layout()
                  .minor_to_major(),
              ElementsAre(2, 0, 1));
}

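// The true and false branches of a conditional produce results with different
// layouts (the false branch's layout is pinned by an infeed). Layout
// assignment must reconcile the branches; here it copies the false branch's
// result into the layout chosen for the conditional.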
TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto m = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      m->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // Use an infeed to pin the layout, since layout assignment does not
    // change infeed layouts.
    auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
    auto infeed = false_builder.AddInstruction(
        HloInstruction::CreateInfeed(xshape, token, ""));
    auto infeed_data = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
  }
  HloComputation* false_computation =
      m->AddEmbeddedComputation(false_builder.Build());
  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
  EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}

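// Bitcast instructions are not expected in the input HLO; layout assignment
// should fail with a descriptive error rather than crash.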
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  builder.AddInstruction(
      HloInstruction::CreateBitcast(constant0->shape(), constant0));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  LayoutAssignment layout_assignment(&computation_layout);
  Status error_status = layout_assignment.Run(m.get()).status();
  EXPECT_FALSE(error_status.ok());
  EXPECT_THAT(
      error_status.error_message(),
      ::testing::HasSubstr(
          "Unexpected bitcast operation seen during layout assignment"));
}

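// Matching send/recv instructions on the same channel must agree on layout.
// With channel constraints supplied, layout assignment should reconcile the
// differing layouts pinned to the parameter and the root.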
TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      token0 = token[] after-all()
      recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
      recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
        sharding={maximal device=1}
      ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
      send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
        sharding={maximal device=0}
      send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_TRUE(ShapeUtil::Equal(FindInstruction(m.get(), "send")->shape(),
                               FindInstruction(m.get(), "recv")->shape()));
}

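// All-reduce instructions that share a channel_id are constrained to use the
// same layout, so the layout pinned via the parameter should propagate to
// both all-reduces even though the root requests {1,0}.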
TEST_F(LayoutAssignmentTest, AllReduceLayoutMismatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      ar.0 = f32[2,2] all-reduce(gte),
        channel_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=0}
      const = f32[2,2] constant({{0,1},{2,3}})
      ROOT ar.1 = f32[2,2] all-reduce(const),
        channel_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

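// A slice whose operand layout differs from its own layout would otherwise
// change layout implicitly; layout assignment should instead insert an
// explicit copy of the operand.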
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopySliceOperandToAvoidImplicitLayoutChange

    ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
      ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
}

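// Same as above, but for dynamic-slice: the operand is copied into the
// expected layout rather than the slice changing layout implicitly.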
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyDSliceOperandToAvoidImplicitLayoutChange

    ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      par2 = s32[] parameter(2)
      par3 = s32[] parameter(3)
      dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
      ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(root,
              GmockMatch(m::Add(
                  m::Parameter(),
                  m::DynamicSlice(
                      m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                      m::Parameter(2), m::Parameter(3)))));
}

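// Same check for concatenate: the operand with a mismatched layout gets an
// explicit copy.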
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyConcatOperandToAvoidImplicitLayoutChange

    ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,8]{1,0} parameter(0)
      par1 = f32[3,5]{0,1} parameter(1)
      par2 = f32[3,3]{1,0} parameter(2)
      concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
        dimensions={1}
      ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                         m::Parameter(2)))));
}

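// Convolution may absorb an implicit layout change, so no copy of its
// operands is expected even though their layouts are unusual.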
TEST_F(LayoutAssignmentTest,
       ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
  const char* module_str = R"(
    HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied

    ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
      par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
      par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
      ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
        window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
        feature_group_count=1
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}

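// A layout pinned on the result should propagate backwards through the slice
// to its operand; the copy that adjusts the layout lands on the parameter.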
TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
  const char* module_str = R"(
    HloModule PropagatingLayoutFromResultToOperand

    ENTRY PropagatingLayoutFromResultToOperand {
      par0 = f32[4,5]{1,0} parameter(0)
      ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
  EXPECT_THAT(root,
              GmockMatch(m::Slice(
                  m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
}

TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
  // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
  // The mismatch forces a copy of the tuple. The tuple contains a token, so
  // layout assignment will fail if it tries to copy the whole tuple.
  const char* module_str = R"(
    HloModule TupleCopyOnLayoutMismatch

    condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
      tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.1 = s32[] get-tuple-element(tup.1), index=0
      five = s32[] constant(5)
      ROOT lt = pred[] compare(counter.1, five), direction=LT
    }

    body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
      tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.2 = s32[] get-tuple-element(tup.2), index=0
      tok.2 = token[] get-tuple-element(tup.2), index=1

      ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
      next_tok = token[] get-tuple-element(ifeed.2), index=1
      next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0

      one = s32[] constant(1)
      next_counter = s32[] add(counter.2, one)
      ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
    }

    ENTRY main () -> f32[512,1024]{0,1} {
      start_tok = token[] after-all()

      ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
      itok = token[] get-tuple-element(ifeed.3), index=1
      ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0

      zero = s32[] constant(0)
      itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)

      loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
      ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  // Sanity check to verify that there's a layout mismatch.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));

  AssignLayouts(m.get(), &computation_layout);

  // Make sure that layout assignment did not magically eliminate the mismatch,
  // in which case the test didn't prove anything.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));
}

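// A custom call without operand_layout_constraints places no layout
// constraints of its own, so its operand and result should simply take the
// layouts of the entry computation.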
TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
  const char* module_str = R"(
    HloModule CustomCallNotLayoutConstrained

    ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
      %p = f32[42,2,3] parameter(0)
      ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
    }
  )";
  // Try a couple of different layouts. In each case the custom call's operand
  // and result layouts should match those of the computation.
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
  }
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
  }
}

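// A custom call with operand_layout_constraints and a result layout baked
// into its shape keeps those layouts; copies are inserted around it to
// bridge any mismatch with the computation layout.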
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
  const char* module_str = R"(
    HloModule CustomCallLayoutConstrained

    ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
      %p0 = f32[4,4] parameter(0)
      %p1 = f32[2,3] parameter(1)
      ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  // The custom call should be partially encapsulated in kCopy instructions
  // because of the layout mismatches.
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

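// The zero-operand case: only the custom call's result layout is constrained,
// so a single copy to the computation's result layout suffices.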
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
  const char* module_str = R"(
    HloModule CustomCallLayoutConstrainedZeroOperands

    ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
      ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall())));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
}

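// Layout constraints on a tuple operand apply element-wise; the tuple fed to
// the custom call should end up with the constrained per-element layouts.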
TEST_F(LayoutAssignmentTest,CustomCallLayoutConstrainedTupleOperand)1233 TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
1234 const char* module_str = R"(
1235 HloModule CustomCallLayoutConstrainedTupleOperand
1236
1237 ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
1238 %p0 = f32[4,4] parameter(0)
1239 %p1 = f32[2,3] parameter(1)
1240 %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
1241 ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
1242 }
1243 )";
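  // Here the operand constraint is itself a tuple shape: element 0 must be
  // {1,0} and element 1 must be {0,1}, whatever layouts the parameters get.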
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  HloInstruction* root = m->entry_computation()->root_instruction();
  ExpectLayoutIs(root->shape(), {2, 1, 0, 3});

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));
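  // The tuple can be assigned the constrained per-element layouts directly,
  // so no copy is needed between it and the custom call; only the result is
  // copied to satisfy the entry layout.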

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleResult

ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
  %p0 = f32[4,4] parameter(0)
  ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
}
)";
  // The custom call's result layout is fixed by its declared shape
  // ({1,0}, {0,1}), so it should survive layout assignment even though the
  // entry computation requests {1,0} for both tuple elements.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
  AssignLayouts(m.get(), &computation_layout);

  ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}});

  const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call");
  ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
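  // The {0,1} vs. {1,0} mismatch on element 1 is presumably bridged by an
  // inserted copy on the way to the root, which this test does not inspect.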
}

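// Runs layout assignment using the module's own entry computation layout,
// defaulting the result layout first if it has not been set.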
Status AssignLayoutsToComputation(
    HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
  if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
    m->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->SetToDefaultLayout();
  }
  LayoutAssignment layout_assignment(
      m->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, channel_constraints);
  return layout_assignment.Run(m).status();
}

TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
  // Check that we handle a diamond-shaped graph correctly.
  //    transpose
  //     /      \
  //   add      |
  //     \      /
  //      tuple

  auto b = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {12, 8});
  Shape bshape = ShapeUtil::MakeShape(F32, {8, 12});
  auto param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input"));
  auto param1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input"));
  auto transpose =
      b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0}));
  auto add = b.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1));
  b.AddInstruction(HloInstruction::CreateTuple({add, transpose}));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(b.Build());
  Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
  Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
  *m->mutable_entry_computation_layout()->mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor}));
  const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0});
  ForceParameterLayout(m.get(), 0, r2_dim0major);
  ForceParameterLayout(m.get(), 1, r2_dim0major);
  TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));

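  // The result tuple pins add to {1,0} and transpose to {0,1}; the conflict
  // on the diamond is presumably resolved by copying the transpose's value
  // into {1,0} before it reaches the add.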
  EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));

  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

// Tests that the layout assignment supports layout-constrained all-reduce
// with different operand layouts (b/146056839).
TEST_F(LayoutAssignmentTest, LayoutConstrainedAllReduce) {
  const char* module_str = R"(
HloModule test_module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry_computation {
  param = (f32[8,4]{0,1}, f32[16,2]{0,1}) parameter(0)
  gte0 = f32[8,4] get-tuple-element(param), index=0
  gte1 = f32[16,2] get-tuple-element(param), index=1
  crs = (f32[8,4]{0,1}, f32[16,2]{1,0}) all-reduce(gte0, gte1),
    replica_groups={}, constrain_layout=true, to_apply=add
  gte2 = f32[8,4] get-tuple-element(crs), index=0
  gte3 = f32[16,2] get-tuple-element(crs), index=1
  ROOT result = (f32[8,4]{1,0}, f32[16,2]{1,0}) tuple(gte2, gte3)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

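  // The constrained layouts on the all-reduce must be preserved: element 0
  // stays {0,1} and element 1 stays {1,0}, with any mismatch against the
  // parameter layouts resolved on the operand side.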
  const HloInstruction* crs = FindInstruction(m.get(), "crs");
  ExpectTupleLayoutIs(crs->shape(), {{0, 1}, {1, 0}});
  ExpectLayoutIs(crs->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(crs->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
  const char* module_str = R"(
HloModule test_module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry_computation {
  param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
  gte0 = f32[16,4] get-tuple-element(param), index=0
  gte1 = f32[16,4] get-tuple-element(param), index=1
  alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-to-all(gte0, gte1),
    replica_groups={{0,1}}, constrain_layout=true
  gte2 = f32[16,4] get-tuple-element(alltoall), index=0
  gte3 = f32[16,4] get-tuple-element(alltoall), index=1
  ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> m,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

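  // Both all-to-all outputs are constrained to {1,0}; gte0 arrives as {0,1}
  // from the parameter, so its layout must be converted on the operand side.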
  const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall");
  ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}});
  ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0});
  ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, DynamicRoot) {
  const char* module_str = R"(
HloModule test_module

ENTRY entry_computation {
  param = f32[1,<=16]{0,1} parameter(0)
  ROOT abs = f32[1,<=16]{0,1} abs(param)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
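  // Clear the dynamic-dimension information from the requested result layout;
  // the assigned root shape should nonetheless keep dimension 1 dynamic.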
  computation_layout.mutable_result_layout()->ClearDynamicShape();

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* abs = FindInstruction(m.get(), "abs");
  ExpectLayoutIs(abs->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(abs->shape(), {0, 1});
  EXPECT_TRUE(abs->shape().is_dynamic_dimension(1));
}

}  // namespace
}  // namespace xla