/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace m = xla::match;
using ::testing::ElementsAre;

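// Test fixture that runs the LayoutAssignment pass on an HloModule and
// provides helpers for inspecting the layouts it assigns.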
class LayoutAssignmentTest : public HloTestBase {
 protected:
  void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
                     ChannelLayoutConstraints* channel_constraints = nullptr) {
    LayoutAssignment layout_assignment(
        entry_computation_layout,
        /*channel_constraints=*/channel_constraints);
    EXPECT_IS_OK(layout_assignment.Run(m).status());
  }

  std::vector<int64> LayoutOf(HloModule* m, absl::string_view name) {
    auto minor_to_major =
        FindInstruction(m, name)->shape().layout().minor_to_major();
    return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
  }

  void ExpectLayoutIs(const Shape& shape,
                      absl::Span<const int64> minor_to_major) {
    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
        << "Expected layout " << expected << ", actual " << shape.layout();
  }

  void ExpectTupleLayoutIs(
      const Shape& shape,
      std::initializer_list<absl::Span<const int64>> minor_to_majors) {
    int i = 0;
    for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
      const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
      const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
      EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
          << "Expected tuple element " << i << " layout " << expected
          << ", actual " << actual;
      ++i;
    }
  }
};

TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout for two different layouts.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ashape, "param0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, ashape, "param1"));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = shape_layout;
    *computation_layout.mutable_parameter_layout(1) = shape_layout;
    *computation_layout.mutable_result_layout() = shape_layout;
    AssignLayouts(m.get(), &computation_layout);
    EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
  }
}

TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout, which has mixed layout.
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build());

  Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
  Shape col_major_shape(ashape);
  *col_major_shape.mutable_layout() = col_major_layout;
  const ShapeLayout col_major(col_major_shape);

  Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
  Shape row_major_shape(ashape);
  *row_major_shape.mutable_layout() = row_major_layout;
  const ShapeLayout row_major(row_major_shape);

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = col_major;
  *computation_layout.mutable_parameter_layout(1) = row_major;
  *computation_layout.mutable_result_layout() = col_major;

  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      col_major_layout, computation->root_instruction()->shape().layout()));
}

TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // Verify that the layouts of the fused parameters in a fusion instruction
  // match those of the fusion operands. Other fused instructions should have
  // no layout.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
    auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
    Shape ashape = constant_literal1.shape();

    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal1)));
    auto constant2 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal2)));
    auto add = builder.AddInstruction(HloInstruction::CreateBinary(
        ashape, HloOpcode::kAdd, constant1, constant2));
    auto negate1 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
    auto negate2 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));

    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    auto fusion = computation->CreateFusionInstruction(
        {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = shape_layout;

    AssignLayouts(m.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_expression_root()->shape().layout()));

    // Inner fused node should not have layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}

TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Verify the layouts of a tuple are assigned properly (the element layouts
  // match their source).
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  // To avoid having to construct a tuple layout in the ComputationLayout
  // below, make the result of the instruction be an array.
  auto get_element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kNegate, get_element0));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      negate->shape(), computation_layout.result_layout().shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
}

TEST_F(LayoutAssignmentTest, TupleSelect) {
  // Verify that the layout of a select with tuple operands is assigned
  // properly.
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));

  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape =
      ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}

TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct the following computation, which has conflicting layouts for
  // two elements of a tuple that share the same source logical buffer:
  //
  //   %constant = Constant(...)
  //   %inner_tuple = Tuple(%constant)
  //   %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout is col-major for the first element and row-major for
  // the second. This creates a conflict where the element of the inner_tuple
  // needs to be both col- and row-major. The conflict is resolved by
  // deep-copying the tuple and assigning the layouts of the copied arrays as
  // needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  LayoutAssignment layout_assignment(&computation_layout);
  AssignLayouts(m.get(), &computation_layout);

  // Layout assignment should have deep copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  //   %constant = Constant(...) layout={1,0}
  //   %tuple.0 = Tuple(%constant) layout=({1,0})
  //   %copy = Copy(%constant) layout={0,1}  # layout transposed
  //   %tuple.1 = Tuple(%copy) layout=({0,1})
  //   %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
  HloInstruction* root = m->entry_computation()->root_instruction();
  // Verify layout of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
                                  m::Tuple(m::Copy(m::Op().Is(constant))))));
}

TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // param -> log -> reshape -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));

  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  auto log_minor_to_major =
      AsInt64Slice(log->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
            PositionInContainer(log_minor_to_major, 2));

  auto reshape_minor_to_major =
      AsInt64Slice(reshape->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
            PositionInContainer(reshape_minor_to_major, 2));
}

// Test whether LayoutAssignment assigns layouts to elementwise operations to
// keep linear indices valid across them, and to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // param -> log -> transpose -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(bshape, log, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
  auto m = CreateNewVerifiedModule();
  auto computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
                                transpose->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
}

// Test whether LayoutAssignment assigns layouts to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // param -> broadcast -> transpose
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
  Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));

  Shape input_shape_with_layout(ashape);
  Shape output_shape_with_layout(cshape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *output_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_shape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}

TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] ------> transpose[4x3] ------------> tuple
  //                 \                                              /
  //                  \-> tanh[3x4] -> broadcast2[2x3x4] ----------/
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore, `broadcast`
  // (the operand of transpose) is expected to have layout {0,1} so that the
  // transpose is a bitcast. Furthermore, `tanh` is expected to have the same
  // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
  Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
  Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
  Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_4, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_34, param, {1}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
  auto broadcast2 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose, broadcast2}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tuple));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  Shape param_shape_with_layout(f32_4);
  Shape transpose_shape_with_layout(f32_43);
  Shape broadcast2_shape_with_layout(f32_234);
  *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  *transpose_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});
  *broadcast2_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_shape_with_layout, broadcast2_shape_with_layout}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

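// A LayoutAssignment specialization whose constraint propagation forces every
// same-rank operand to take the layout of the instruction's output buffer.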
class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const LogicalBuffer& buffer = buffer_constraint.buffer();
    const HloInstruction* instruction = buffer.instruction();

    // Force the operands' layout to the output layout.
    for (int64_t operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const HloInstruction* operand = instruction->operand(operand_no);
      if (instruction->shape().rank() != operand->shape().rank()) {
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), instruction, operand_no,
          /*mandatory=*/true));
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};

TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // param0 -> concatenate -> reshape
  // param1 -^
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
  Shape cshape = ShapeUtil::MakeShape(F32, {100});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param"));
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
  auto reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(cshape, concatenate));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(reshape));

  Shape param0_shape_with_layout(ashape);
  Shape param1_shape_with_layout(ashape);
  *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param0_shape_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(param1_shape_with_layout);
  OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());

  EXPECT_EQ(concatenate->operand(0)->shape().layout().minor_to_major(),
            concatenate->operand(1)->shape().layout().minor_to_major());
  EXPECT_EQ(concatenate->shape().layout().minor_to_major(),
            concatenate->operand(1)->shape().layout().minor_to_major());
}

// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape_with_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(input_shape, constant, {}));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// TransposeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo = builder.AddInstruction(
      HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1}));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(),
                                             hlo->shape(), hlo->dimensions()),
               "LayoutUtil::HasLayout");
}

// ReshapeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo =
      builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(
      ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()),
      "LayoutUtil::HasLayout");
}

// Check that the computation below doesn't crash the compiler.
//
// Within a fusion computation, only the parameters and result get assigned a
// layout. When we run the algebraic simplifier on this computation post layout
// assignment, it should not call TransposeIsBitcast on the `transpose` node
// inside the fusion computation as TransposeIsBitcast checks both input_shape
// and output_shape have layouts.
TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      param_1 = f32[2,2,2]{2,1,0} parameter(1)
      transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
      reduce_1 = f32[] parameter(0)
      broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
      ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
    }

    ENTRY entry_computation {
      fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
      reduce.1 = f32[] parameter(0)
      fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
      ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  std::unique_ptr<HloModule> compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();

  EXPECT_EQ(Status::OK(), backend()
                              .compiler()
                              ->RunBackend(std::move(compiled_module),
                                           backend().default_stream_executor(),
                                           /*device_allocator=*/nullptr)
                              .status());
}

// A GTE inside of a fusion node inherits the layout of its operand (which
// should, if we keep following operands, eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
      gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
      gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
      gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
      add = f32[2,2,2] add(gte1a, gte1b)
      ROOT fresult = f32[2,2,2] add(gte0, add)
    }

    ENTRY entry_computation {
      param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      ROOT fusion =
        f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
       ShapeUtil::MakeTupleShape({
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
       })});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2));
  EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0));
  EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(0)
                  .layout()
                  .minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(1)
                  .layout()
                  .minor_to_major(),
              ElementsAre(2, 0, 1));
}

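// When the two branches of a conditional produce the same result with
// different layouts, a copy should be inserted so that both branch roots
// yield the layout expected by the conditional.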
TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto m = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      m->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // Use an infeed here, since layout assignment does not modify its layout.
    auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
    auto infeed = false_builder.AddInstruction(
        HloInstruction::CreateInfeed(xshape, token, ""));
    auto infeed_data = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
  }
  HloComputation* false_computation =
      m->AddEmbeddedComputation(false_builder.Build());
  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
  EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}

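// Running layout assignment on a module that already contains a bitcast
// should fail: bitcasts are only introduced once layouts have been assigned.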
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  builder.AddInstruction(
      HloInstruction::CreateBitcast(constant0->shape(), constant0));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  LayoutAssignment layout_assignment(&computation_layout);
  Status error_status = layout_assignment.Run(m.get()).status();
  EXPECT_FALSE(error_status.ok());
  EXPECT_THAT(
      error_status.error_message(),
      ::testing::HasSubstr(
          "Unexpected bitcast operation seen during layout assignment"));
}

TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      token0 = token[] after-all()
      recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
      recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
        sharding={maximal device=1}
      ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
      send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
        sharding={maximal device=0}
      send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_TRUE(ShapeUtil::Equal(FindInstruction(m.get(), "send")->shape(),
                               FindInstruction(m.get(), "recv")->shape()));
}

TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      ar.0 = f32[2,2] all-reduce(gte),
        channel_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=0}
      const = f32[2,2] constant({{0,1},{2,3}})
      ROOT ar.1 = f32[2,2] all-reduce(const),
        channel_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

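// A slice should not change the layout of its operand implicitly; instead, a
// copy of the operand with the expected layout should be inserted.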
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopySliceOperandToAvoidImplicitLayoutChange

    ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
      ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
}

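// Same as above, but for dynamic-slice.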
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyDSliceOperandToAvoidImplicitLayoutChange

    ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      par2 = s32[] parameter(2)
      par3 = s32[] parameter(3)
      dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
      ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(root,
              GmockMatch(m::Add(
                  m::Parameter(),
                  m::DynamicSlice(
                      m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                      m::Parameter(2), m::Parameter(3)))));
}

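// Same as above, but for concatenate.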
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyConcatOperandToAvoidImplicitLayoutChange

    ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,8]{1,0} parameter(0)
      par1 = f32[3,5]{0,1} parameter(1)
      par2 = f32[3,3]{1,0} parameter(2)
      concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
        dimensions={1}
      ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                         m::Parameter(2)))));
}

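// Convolution, in contrast, is allowed to change its operands' layouts
// implicitly, so no copies of the operands should be inserted.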
TEST_F(LayoutAssignmentTest,
       ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
  const char* module_str = R"(
    HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied

    ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
      par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
      par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
      ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
        window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
        feature_group_count=1
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}

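// The layout {0,1} requested for the slice result propagates backwards: a
// copy of the parameter materializes the layout the slice operand needs.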
TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
  const char* module_str = R"(
    HloModule PropagatingLayoutFromResultToOperand

    ENTRY PropagatingLayoutFromResultToOperand {
      par0 = f32[4,5]{1,0} parameter(0)
      ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
  EXPECT_THAT(root,
              GmockMatch(m::Slice(
                  m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
}

TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
  // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
  // The mismatch forces a copy of the tuple. The tuple contains a token, so
  // layout assignment will fail if it tries to copy the whole tuple.
  const char* module_str = R"(
    HloModule TupleCopyOnLayoutMismatch

    condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
      tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.1 = s32[] get-tuple-element(tup.1), index=0
      five = s32[] constant(5)
      ROOT lt = pred[] compare(counter.1, five), direction=LT
    }

    body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
      tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.2 = s32[] get-tuple-element(tup.2), index=0
      tok.2 = token[] get-tuple-element(tup.2), index=1

      ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
      next_tok = token[] get-tuple-element(ifeed.2), index=1
      next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0

      one = s32[] constant(1)
      next_counter = s32[] add(counter.2, one)
      ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
    }

    ENTRY main () -> f32[512,1024]{0,1} {
      start_tok = token[] after-all()

      ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
      itok = token[] get-tuple-element(ifeed.3), index=1
      ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0

      zero = s32[] constant(0)
      itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)

      loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
      ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  // Sanity check to verify that there's a layout mismatch.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));

  AssignLayouts(m.get(), &computation_layout);

  // Make sure that layout assignment did not magically eliminate the mismatch,
  // in which case the test didn't prove anything.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));
}

TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
  const char* module_str = R"(
    HloModule CustomCallNotLayoutConstrained

    ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
      %p = f32[42,2,3] parameter(0)
      ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
    }
  )";
  // Try a couple of different layouts. In each case the custom call's operand
  // and result layouts should match those of the computation.
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
  }
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
  }
}

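// When a custom call constrains its operand layouts, mismatches with the
// computation layout are resolved by copies around the call.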
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
  const char* module_str = R"(
    HloModule CustomCallLayoutConstrained

    ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
      %p0 = f32[4,4] parameter(0)
      %p1 = f32[2,3] parameter(1)
      ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  // The custom call should be partially encapsulated in kCopy instructions
  // because of the layout mismatches.
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

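// When the custom call result aliases an operand
// (output_to_operand_aliasing), the result keeps the aliased operand's
// constrained layout, so no copy is inserted at the root.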
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedAliasedOutput) {
  const char* module_str = R"(
    HloModule customcall.4

    ENTRY %customcall.4 (parameter.1: f32[8,128], parameter.2: f32[8,128]) -> f32[8,128] {
      %parameter.1 = f32[8,128]{1,0} parameter(0)
      %parameter.2 = f32[8,128]{1,0} parameter(1)
      ROOT %custom-call.3 = f32[8,128]{1,0} custom-call(f32[8,128]{1,0} %parameter.1, f32[8,128]{1,0} %parameter.2), custom_call_target="gpu_example_custom_call", operand_layout_constraints={f32[8,128]{1,0}, f32[8,128]{1,0}}, custom_call_has_side_effect=true, output_to_operand_aliasing={{}: (0, {})}
    })";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {8, 128}, {1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction();
  ExpectLayoutIs(custom_call->shape(), {1, 0});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {1, 0});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedZeroOperands

ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}
)";
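  // With no operands there is nothing to constrain on the input side; only
  // the custom call's {3,2,0,1} result layout is fixed, and a copy is needed
  // to produce the {2,1,0,3} entry result layout requested below.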
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall())));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleOperand

ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
}
)";
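  // The operand constraint applies element-wise to the tuple operand: {1,0}
  // for the f32[4,4] element and {0,1} for the f32[2,3] element.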
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  HloInstruction* root = m->entry_computation()->root_instruction();
  ExpectLayoutIs(root->shape(), {2, 1, 0, 3});

  ASSERT_THAT(root, GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));

  const HloInstruction* custom_call = root->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleResult

ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
  %p0 = f32[4,4] parameter(0)
  ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
}
)";
  // The custom call's result layout is fixed by the layouts written on its
  // shape, so it must keep ({1,0}, {0,1}) even though the entry computation
  // requests {1,0} for both tuple elements; layout assignment has to
  // reconcile the second element separately.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
  AssignLayouts(m.get(), &computation_layout);

  ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}});

  const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call");
  ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
}

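// Runs layout assignment using the module's own entry computation layout,
// defaulting the result layout first if the caller has not set one.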
Status AssignLayoutsToComputation(
    HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
  if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
    m->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->SetToDefaultLayout();
  }
  LayoutAssignment layout_assignment(m->mutable_entry_computation_layout(),
                                     channel_constraints);
  return layout_assignment.Run(m).status();
}

TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
  // Check that we handle a diamond-shaped graph correctly.
  //   transpose
  //    /     \
  //  add      |
  //    \     /
  //     tuple

  auto b = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {12, 8});
  Shape bshape = ShapeUtil::MakeShape(F32, {8, 12});
  auto param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input"));
  auto param1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input"));
  auto transpose = b.AddInstruction(
      HloInstruction::CreateTranspose(ashape, param0, {1, 0}));
  auto add = b.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1));
  b.AddInstruction(HloInstruction::CreateTuple({add, transpose}));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(b.Build());
  Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
  Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
  *m->mutable_entry_computation_layout()->mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor}));
  const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0});
  ForceParameterLayout(m.get(), 0, r2_dim0major);
  ForceParameterLayout(m.get(), 1, r2_dim0major);
  TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));

  EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));

  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

// Tests that layout assignment supports layout-constrained all-reduce with
// different operand layouts (b/146056839).
TEST_F(LayoutAssignmentTest, LayoutConstrainedAllReduce) {
  const char* module_str = R"(
HloModule test_module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry_computation {
  param = (f32[8,4]{0,1}, f32[16,2]{0,1}) parameter(0)
  gte0 = f32[8,4] get-tuple-element(param), index=0
  gte1 = f32[16,2] get-tuple-element(param), index=1
  crs = (f32[8,4]{0,1}, f32[16,2]{1,0}) all-reduce(gte0, gte1),
    replica_groups={}, constrain_layout=true, to_apply=add
  gte2 = f32[8,4] get-tuple-element(crs), index=0
  gte3 = f32[16,2] get-tuple-element(crs), index=1
  ROOT result = (f32[8,4]{1,0}, f32[16,2]{1,0}) tuple(gte2, gte3)
}
)";
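  // constrain_layout=true fixes the all-reduce's operand and result layouts
  // to the layouts written in the HLO above ({0,1} for the first element and
  // {1,0} for the second), independent of what the users of 'crs' prefer.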

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  const HloInstruction* crs = FindInstruction(m.get(), "crs");
  ExpectTupleLayoutIs(crs->shape(), {{0, 1}, {1, 0}});
  ExpectLayoutIs(crs->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(crs->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
  const char* module_str = R"(
HloModule test_module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry_computation {
  param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
  gte0 = f32[16,4] get-tuple-element(param), index=0
  gte1 = f32[16,4] get-tuple-element(param), index=1
  alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-to-all(gte0, gte1),
    replica_groups={{0,1}}, constrain_layout=true
  gte2 = f32[16,4] get-tuple-element(alltoall), index=0
  gte3 = f32[16,4] get-tuple-element(alltoall), index=1
  ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
}
)";
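  // Both all-to-all elements are constrained to {1,0} even though the root
  // concatenate asks for {0,1}, so the constrained layouts must win on the
  // collective itself.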

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> m,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall");
  ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}});
  ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0});
  ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, DynamicRoot) {
  const char* module_str = R"(
HloModule test_module

ENTRY entry_computation {
  param = f32[1,<=16]{0,1} parameter(0)
  ROOT abs = f32[1,<=16]{0,1} abs(param)
}
)";
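  // Dimension 1 is dynamic (bounded by 16). Layout assignment should assign
  // the {0,1} layout while preserving the dynamic dimension in the shape.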

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
  computation_layout.mutable_result_layout()->ClearDynamicShape();

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* abs = FindInstruction(m.get(), "abs");
  ExpectLayoutIs(abs->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(abs->shape(), {0, 1});
  EXPECT_TRUE(abs->shape().is_dynamic_dimension(1));
}

// Tests that copies across computation boundaries can be avoided by
// reversing the computation traversal order.
TEST_F(LayoutAssignmentTest, ReverseComputationOrderAvoidCopy) {
  const char* module_str = R"(
HloModule ComputationLayoutAvoidCopy

call_1 {
  %arg_tuple.1 = (f32[93184,4]) parameter(0)
  %get-tuple-element.1 = f32[93184,4] get-tuple-element(%arg_tuple.1), index=0
  ROOT %reshape.8494 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
}

on_true {
  %arg_tuple.1 = (f32[93184,4]) parameter(0)
  %get-tuple-element.1 = f32[93184,4] get-tuple-element(%arg_tuple.1), index=0
  ROOT %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1)
}

on_false {
  %arg_tuple.2 = (f32[93184,4]) parameter(0)
  %get-tuple-element.3 = f32[93184,4] get-tuple-element(%arg_tuple.2), index=0
  %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3)
  ROOT %add = f32[2,512,364] add(%reshape.9717, %reshape.9717)
}

ENTRY main {
  pred.1 = pred[] parameter(0)
  arg.2 = f32[93184,4]{1,0} parameter(1)
  arg_tuple.11 = (f32[93184,4]{1,0}) tuple(arg.2)
  call.1 = f32[2,512,364] call(arg_tuple.11), to_apply=call_1
  conditional = f32[2,512,364] conditional(pred.1, arg_tuple.11, arg_tuple.11),
    true_computation=on_true, false_computation=on_false
  ROOT add = f32[2,512,364] add(call.1, conditional)
}
)";
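  // With reverse_computation_order=true the reshapes inside the callees
  // should pick up the {0,1,2} entry result layout directly, avoiding copies
  // at the call and conditional boundaries.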

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShape(PRED, {}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {93184, 4}, {0, 1}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {2, 512, 364}, {0, 1, 2}));
  ChannelLayoutConstraints channel_constraints;
  LayoutAssignment layout_assignment(
      &computation_layout,
      /*channel_constraints=*/&channel_constraints,
      /*reverse_computation_order=*/true);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());

  const HloInstruction* call_1 = FindInstruction(m.get(), "reshape.8494");
  ExpectLayoutIs(call_1->shape(), {0, 1, 2});
  const HloInstruction* on_true = FindInstruction(m.get(), "reshape.8493");
  ExpectLayoutIs(on_true->shape(), {0, 1, 2});
  const HloInstruction* on_false = FindInstruction(m.get(), "reshape.9717");
  ExpectLayoutIs(on_false->shape(), {0, 1, 2});
}

}  // namespace
}  // namespace xla