/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace m = xla::match;
using ::testing::ElementsAre;

class LayoutAssignmentTest : public HloTestBase {
 protected:
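  // Runs the LayoutAssignment pass on |m| with the given entry computation
  // layout and optional channel constraints, and expects it to succeed.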
  void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
                     ChannelLayoutConstraints* channel_constraints = nullptr) {
    LayoutAssignment layout_assignment(
        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
        /*channel_constraints=*/channel_constraints);
    EXPECT_IS_OK(layout_assignment.Run(m).status());
  }

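  // Returns the minor-to-major order of the layout of the named instruction.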
  std::vector<int64> LayoutOf(HloModule* m, absl::string_view name) {
    auto minor_to_major =
        FindInstruction(m, name)->shape().layout().minor_to_major();
    return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
  }

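  // Expects that |shape| has exactly the layout given by |minor_to_major|.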
  void ExpectLayoutIs(const Shape& shape,
                      absl::Span<const int64> minor_to_major) {
    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
        << "Expected layout " << expected << ", actual " << shape.layout();
  }

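  // Expects that each element of the tuple |shape| has the corresponding
  // layout in |minor_to_majors|.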
  void ExpectTupleLayoutIs(
      const Shape& shape,
      std::initializer_list<absl::Span<const int64>> minor_to_majors) {
    int i = 0;
    for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
      const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
      const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
      EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
          << "Expected tuple element " << i << " layout " << expected
          << ", actual " << actual;
      ++i;
    }
  }
};

TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout for two different layouts.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ashape, "param0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, ashape, "param1"));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = shape_layout;
    *computation_layout.mutable_parameter_layout(1) = shape_layout;
    *computation_layout.mutable_result_layout() = shape_layout;
    AssignLayouts(m.get(), &computation_layout);
    EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
  }
}

TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout, which has mixed layouts.
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build());

  Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
  Shape col_major_shape(ashape);
  *col_major_shape.mutable_layout() = col_major_layout;
  const ShapeLayout col_major(col_major_shape);

  Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
  Shape row_major_shape(ashape);
  *row_major_shape.mutable_layout() = row_major_layout;
  const ShapeLayout row_major(row_major_shape);

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = col_major;
  *computation_layout.mutable_parameter_layout(1) = row_major;
  *computation_layout.mutable_result_layout() = col_major;

  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      col_major_layout, computation->root_instruction()->shape().layout()));
}

TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // Verify that the layouts of the fused parameters in a fusion instruction
  // match those of the fusion operands. Other fused instructions should have
  // no layout.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
    auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
    Shape ashape = constant_literal1.shape();

    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal1)));
    auto constant2 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal2)));
    auto add = builder.AddInstruction(HloInstruction::CreateBinary(
        ashape, HloOpcode::kAdd, constant1, constant2));
    auto negate1 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
    auto negate2 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));

    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    auto fusion = computation->CreateFusionInstruction(
        {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = shape_layout;

    AssignLayouts(m.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_expression_root()->shape().layout()));

    // The inner fused node should not have a layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}

TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Verify the layouts of a tuple are assigned properly (the element layouts
  // match their source).
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  // To avoid having to construct a tuple layout in the ComputationLayout
  // below, make the computation result an array.
  auto get_element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kNegate, get_element0));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      negate->shape(), computation_layout.result_layout().shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
}

TEST_F(LayoutAssignmentTest, TupleSelect) {
  // Verify that the layouts of a select with tuple operands are assigned
  // properly.
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));

  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape =
      ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}

TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct the following computation, which has conflicting layouts for
  // two elements of a tuple that share the same source logical buffer:
  //
  // %constant = Constant(...)
  // %inner_tuple = Tuple(%constant)
  // %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout is col-major for the first element and row-major for the
  // second. This creates a conflict where the element of the inner_tuple
  // needs to be both col and row major. This is resolved by deep-copying the
  // tuple and assigning the layouts of the copied arrays as needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  // Layout assignment should have deep-copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  //  %constant = Constant(...) layout={1,0}
  //  %tuple.0 = Tuple(%constant) layout=({1,0})
  //  %copy = Copy(%constant) layout={0,1}  # layout transposed
  //  %tuple.1 = Tuple(%copy) layout=({0,1})
  //  %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
  HloInstruction* root = m->entry_computation()->root_instruction();
  // Verify the layouts of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
                                  m::Tuple(m::Copy(m::Op().Is(constant))))));
}

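// Verify that an elementwise operation and a reshape between two fixed
// layouts get layouts that respect both: the log is expected to keep the
// relative dimension ordering of the parameter layout, and the reshape that
// of the result layout.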
TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // param -> log -> reshape -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));

  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  auto log_minor_to_major =
      AsInt64Slice(log->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
            PositionInContainer(log_minor_to_major, 2));

  auto reshape_minor_to_major =
      AsInt64Slice(reshape->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
            PositionInContainer(reshape_minor_to_major, 2));
}

// Test whether LayoutAssignment assigns layouts to elementwise operations to
// keep linear indices valid across them, and to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // param -> log -> transpose -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(bshape, log, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
  auto m = CreateNewVerifiedModule();
  auto computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
                                transpose->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
}

// Test whether LayoutAssignment assigns layouts to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // param -> broadcast -> transpose
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
  Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));

  Shape input_shape_with_layout(ashape);
  Shape output_shape_with_layout(cshape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *output_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_shape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}

TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] ------> transpose[4x3]-------- -------> tuple
  //                            \                                     /
  //                             \-> tanh[3x4] -> broadcast2[2x3x4] -/
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore, `broadcast`
  // (the operand of transpose) is expected to have layout {0,1} so that the
  // transpose is a bitcast. Furthermore, `tanh` is expected to have the same
  // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
  Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
  Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
  Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_4, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_34, param, {1}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
  auto broadcast2 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose, broadcast2}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tuple));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  Shape param_shape_with_layout(f32_4);
  Shape transpose_shape_with_layout(f32_43);
  Shape broadcast2_shape_with_layout(f32_234);
  *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  *transpose_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});
  *broadcast2_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_shape_with_layout, broadcast2_shape_with_layout}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

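// A LayoutAssignment pass that forces each same-rank operand of an
// instruction to take the layout assigned to the instruction's output buffer.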
class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const LogicalBuffer& buffer = buffer_constraint.buffer();
    const HloInstruction* instruction = buffer.instruction();

    // Force the operands' layout to the output layout.
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const HloInstruction* operand = instruction->operand(operand_no);
      if (instruction->shape().rank() != operand->shape().rank()) {
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), instruction, operand_no,
          /*mandatory=*/true));
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};

TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // param0 -> concatenate -> reshape
  // param1   -^
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
  Shape cshape = ShapeUtil::MakeShape(F32, {100});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param"));
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
  auto reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(cshape, concatenate));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(reshape));

  Shape param0_shape_with_layout(ashape);
  Shape param1_shape_with_layout(ashape);
  *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param0_shape_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(param1_shape_with_layout);
  OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());

  EXPECT_EQ(concatenate->operand(0)->shape().layout().minor_to_major(),
            concatenate->operand(1)->shape().layout().minor_to_major());
  EXPECT_EQ(concatenate->shape().layout().minor_to_major(),
            concatenate->operand(1)->shape().layout().minor_to_major());
}

// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape_with_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(input_shape, constant, {}));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// TransposeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo = builder.AddInstruction(
      HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1}));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(),
                                             hlo->shape(), hlo->dimensions()),
               "LayoutUtil::HasLayout");
}

// ReshapeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo =
      builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(
      ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()),
      "LayoutUtil::HasLayout");
}

// Check that the computation below doesn't crash the compiler.
//
// Within a fusion computation, only the parameters and result get assigned a
// layout.  When we run the algebraic simplifier on this computation post
// layout assignment, it should not call TransposeIsBitcast on the `transpose`
// node inside the fusion computation, as TransposeIsBitcast checks that both
// input_shape and output_shape have layouts.
TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      param_1 = f32[2,2,2]{2,1,0} parameter(1)
      transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
      reduce_1 = f32[] parameter(0)
      broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
      ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
    }

    ENTRY entry_computation {
      fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
      reduce.1 = f32[] parameter(0)
      fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
      ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  std::unique_ptr<HloModule> compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();

  EXPECT_EQ(Status::OK(), backend()
                              .compiler()
                              ->RunBackend(std::move(compiled_module),
                                           backend().default_stream_executor(),
                                           /*device_allocator=*/nullptr)
                              .status());
}

// A GTE inside of a fusion node inherits the layout of its operand (which
// should, if we keep following operands, eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
      gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
      gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
      gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
      add = f32[2,2,2] add(gte1a, gte1b)
      ROOT fresult = f32[2,2,2] add(gte0, add)
    }

    ENTRY entry_computation {
      param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      ROOT fusion =
        f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
       ShapeUtil::MakeTupleShape({
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
       })});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2));
  EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0));
  EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(0)
                  .layout()
                  .minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(1)
                  .layout()
                  .minor_to_major(),
              ElementsAre(2, 0, 1));
}

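// The two branches of the conditional produce results with different layouts:
// the false branch's result comes from an infeed whose layout is fixed at
// {0,1}. The test expects layout assignment to give both branch roots equal
// layouts by inserting a copy in the false branch.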
TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto m = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      m->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // Use an infeed, since layout assignment does not modify its layout.
    auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
    auto infeed = false_builder.AddInstruction(
        HloInstruction::CreateInfeed(xshape, token, ""));
    auto infeed_data = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
  }
  HloComputation* false_computation =
      m->AddEmbeddedComputation(false_builder.Build());
  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
  EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}

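// Bitcast semantics depend on the layouts of the operand and result, which
// have not yet been assigned, so a bitcast in the input module should be
// rejected with an error.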
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  builder.AddInstruction(
      HloInstruction::CreateBitcast(constant0->shape(), constant0));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  LayoutAssignment layout_assignment(&computation_layout);
  Status error_status = layout_assignment.Run(m.get()).status();
  EXPECT_FALSE(error_status.ok());
  EXPECT_THAT(
      error_status.error_message(),
      ::testing::HasSubstr(
          "Unexpected bitcast operation seen during layout assignment"));
}

TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
  // Pin non-matching layouts to the parameter and the root.
  const char* module_str = R"(
    HloModule test_module

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      token0 = token[] after-all()
      recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
      recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
        sharding={maximal device=1}
      ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
      send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
        sharding={maximal device=0}
      send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_TRUE(ShapeUtil::Equal(FindInstruction(m.get(), "send")->shape(),
                               FindInstruction(m.get(), "recv")->shape()));
}

TEST_F(LayoutAssignmentTest, AllReduceLayoutMismatch) {
  // Pin non-matching layouts to the parameter and the root.
  const char* module_str = R"(
    HloModule test_module

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      ar.0 = f32[2,2] all-reduce(gte),
        channel_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=0}
      const = f32[2,2] constant({{0,1},{2,3}})
      ROOT ar.1 = f32[2,2] all-reduce(const),
        channel_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

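// par1 has layout {0,1}, but its slice feeds an add whose layout is {1,0}.
// The test expects a copy of the slice operand to be inserted so that the
// slice itself does not implicitly change the layout.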
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopySliceOperandToAvoidImplicitLayoutChange

    ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
      ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
}

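// Same as above, but for dynamic-slice: the operand is expected to be copied
// to the required layout instead of letting the dynamic-slice change layouts.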
TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyDSliceOperandToAvoidImplicitLayoutChange

    ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      par2 = s32[] parameter(2)
      par3 = s32[] parameter(3)
      dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
      ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(root,
              GmockMatch(m::Add(
                  m::Parameter(),
                  m::DynamicSlice(
                      m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                      m::Parameter(2), m::Parameter(3)))));
}

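// Same again for concatenate: the operand whose {0,1} layout disagrees with
// the required {1,0} output layout is expected to be copied.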
TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyConcatOperandToAvoidImplicitLayoutChange

    ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,8]{1,0} parameter(0)
      par1 = f32[3,5]{0,1} parameter(1)
      par2 = f32[3,3]{1,0} parameter(2)
      concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
        dimensions={1}
      ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                         m::Parameter(2)))));
}

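// Convolution is allowed to change layouts between its operands and result,
// so no copies are expected; the convolution should consume the parameters
// directly.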
TEST_F(LayoutAssignmentTest,
       ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
  const char* module_str = R"(
    HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied

    ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
      par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
      par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
      ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
        window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
        feature_group_count=1
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}

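// The {0,1} layout pinned on the slice result should propagate to the slice
// operand, which the test expects to be satisfied by copying the {1,0}
// parameter into a {0,1} layout before the slice.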
TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
  const char* module_str = R"(
    HloModule PropagatingLayoutFromResultToOperand

    ENTRY PropagatingLayoutFromResultToOperand {
      par0 = f32[4,5]{1,0} parameter(0)
      ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
  EXPECT_THAT(root,
              GmockMatch(m::Slice(
                  m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
}

TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
  // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
  // The mismatch forces a copy of the tuple.  The tuple contains a token, so
  // layout assignment will fail if it tries to copy the whole tuple.
  const char* module_str = R"(
    HloModule TupleCopyOnLayoutMismatch

    condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
      tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.1 = s32[] get-tuple-element(tup.1), index=0
      five = s32[] constant(5)
      ROOT lt = pred[] compare(counter.1, five), direction=LT
    }

    body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
      tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.2 = s32[] get-tuple-element(tup.2), index=0
      tok.2 = token[] get-tuple-element(tup.2), index=1

      ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
      next_tok = token[] get-tuple-element(ifeed.2), index=1
      next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0

      one = s32[] constant(1)
      next_counter = s32[] add(counter.2, one)
      ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
    }

    ENTRY main () -> f32[512,1024]{0,1} {
      start_tok = token[] after-all()

      ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
      itok = token[] get-tuple-element(ifeed.3), index=1
      ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0

      zero = s32[] constant(0)
      itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)

      loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
      ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  // Sanity check to verify that there's a layout mismatch.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));

  AssignLayouts(m.get(), &computation_layout);

  // Make sure that layout assignment did not magically eliminate the mismatch,
  // in which case the test didn't prove anything.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));
}

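// A custom call without operand_layout_constraints should simply take
// whatever layouts the surrounding computation assigns to its operand and
// result.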
TEST_F(LayoutAssignmentTest,CustomCallNotLayoutConstrained)1130 TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
1131   const char* module_str = R"(
1132 HloModule CustomCallNotLayoutConstrained
1133 
1134 ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
1135   %p = f32[42,2,3] parameter(0)
1136   ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
1137 }
1138 )";
  // Try a couple of different layouts. In each case the custom call's operand
  // and result layouts should match those of the entry computation.
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
  }
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
  }
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrained

ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
}
)";
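  // Note: parameter 0 is given entry layout {1,0} below while the custom call
  // requires {0,1} for that operand, and the requested result layout
  // {2,1,0,3} differs from the call's fixed {3,2,0,1}, so copies must
  // reconcile both mismatches.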
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  // The custom call should be partially encapsulated in kCopy instructions
  // because of the layout mismatches.
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

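// Even with an empty operand_layout_constraints list, the custom call's result
// layout stays fixed, so only a copy of the result should be inserted.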
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedZeroOperands

ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall())));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
}

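// The constrained tuple operand layout {(f32[4,4]{1,0}, f32[2,3]{0,1})}
// differs from the parameters' natural {1,0} layouts, so the operand tuple
// must be assigned the constrained per-element layouts.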
TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleOperand

ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  HloInstruction* root = m->entry_computation()->root_instruction();
  ExpectLayoutIs(root->shape(), {2, 1, 0, 3});

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleResult

ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
  %p0 = f32[4,4] parameter(0)
  ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
}
)";
  // The custom call's tuple result layout is fixed by its shape, so the
  // requested entry result layout {{1,0},{1,0}} must be produced by copying
  // the mismatched element rather than by changing the custom call's layout.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
  AssignLayouts(m.get(), &computation_layout);

  ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}});

  const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call");
  ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
}

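// Runs layout assignment using the module's own entry computation layout,
// defaulting the result layout first if it has not been set.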
Status AssignLayoutsToComputation(
    HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
  if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
    m->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->SetToDefaultLayout();
  }
  LayoutAssignment layout_assignment(
      m->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, channel_constraints);
  return layout_assignment.Run(m).status();
}

TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
  // Check that we handle a diamond-shaped graph correctly.
  //      transpose
  //       /    \
  //     add    |
  //       \    /
  //        tuple

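  // The tuple's result layout pins its elements: `add` (element 0) gets {1,0}
  // and `transpose` (element 1) gets {0,1}, so the two users of the transpose
  // must end up seeing different layouts.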
  auto b = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {12, 8});
  Shape bshape = ShapeUtil::MakeShape(F32, {8, 12});
  auto param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input"));
  auto param1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input"));
  auto transpose =
      b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0}));
  auto add = b.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1));
  b.AddInstruction(HloInstruction::CreateTuple({add, transpose}));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(b.Build());
  Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
  Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
  *m->mutable_entry_computation_layout()->mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor}));
  const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0});
  ForceParameterLayout(m.get(), 0, r2_dim0major);
  ForceParameterLayout(m.get(), 1, r2_dim0major);
  TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));

  EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));

  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

// Tests that the layout assignment supports layout-constrained all-reduce with
// different operand layouts (b/146056839).
TEST_F(LayoutAssignmentTest, LayoutConstrainedAllReduce) {
  const char* module_str = R"(
HloModule test_module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry_computation {
  param = (f32[8,4]{0,1}, f32[16,2]{0,1}) parameter(0)
  gte0 = f32[8,4] get-tuple-element(param), index=0
  gte1 = f32[16,2] get-tuple-element(param), index=1
  crs = (f32[8,4]{0,1}, f32[16,2]{1,0}) all-reduce(gte0, gte1),
    replica_groups={}, constrain_layout=true, to_apply=add
  gte2 = f32[8,4] get-tuple-element(crs), index=0
  gte3 = f32[16,2] get-tuple-element(crs), index=1
  ROOT result = (f32[8,4]{1,0}, f32[16,2]{1,0}) tuple(gte2, gte3)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  const HloInstruction* crs = FindInstruction(m.get(), "crs");
  ExpectTupleLayoutIs(crs->shape(), {{0, 1}, {1, 0}});
  ExpectLayoutIs(crs->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(crs->operand(1)->shape(), {1, 0});
}

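// Tests that a layout-constrained all-to-all keeps the layouts spelled out in
// the HLO; the first operand arrives as {0,1}, so a copy to {1,0} is needed.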
TEST_F(LayoutAssignmentTest, LayoutConstrainedAllToAll) {
  const char* module_str = R"(
HloModule test_module

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry_computation {
  param = (f32[16,4]{0,1}, f32[16,4]{1,0}) parameter(0)
  gte0 = f32[16,4] get-tuple-element(param), index=0
  gte1 = f32[16,4] get-tuple-element(param), index=1
  alltoall = (f32[16,4]{1,0}, f32[16,4]{1,0}) all-to-all(gte0, gte1),
    replica_groups={{0,1}}, constrain_layout=true
  gte2 = f32[16,4] get-tuple-element(alltoall), index=0
  gte3 = f32[16,4] get-tuple-element(alltoall), index=1
  ROOT concat = f32[16,8]{0,1} concatenate(gte2, gte3), dimensions={1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> m,
      ParseAndReturnVerifiedModule(module_str, /*replica_count=*/2));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  const HloInstruction* alltoall = FindInstruction(m.get(), "alltoall");
  ExpectTupleLayoutIs(alltoall->shape(), {{1, 0}, {1, 0}});
  ExpectLayoutIs(alltoall->operand(0)->shape(), {1, 0});
  ExpectLayoutIs(alltoall->operand(1)->shape(), {1, 0});
}

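// Tests that layout assignment preserves the root's dynamic dimension (<=16)
// even though the requested result layout has its dynamic shape cleared.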
TEST_F(LayoutAssignmentTest, DynamicRoot) {
  const char* module_str = R"(
HloModule test_module

ENTRY entry_computation {
  param = f32[1,<=16]{0,1} parameter(0)
  ROOT abs = f32[1,<=16]{0,1} abs(param)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false);
  computation_layout.mutable_result_layout()->ClearDynamicShape();

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* abs = FindInstruction(m.get(), "abs");
  ExpectLayoutIs(abs->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(abs->shape(), {0, 1});
  EXPECT_TRUE(abs->shape().is_dynamic_dimension(1));
}

}  // namespace
}  // namespace xla