/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/layout_assignment.h"

#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

namespace m = xla::match;
using ::testing::ElementsAre;

class LayoutAssignmentTest : public HloTestBase {
 protected:
  void AssignLayouts(HloModule* m, ComputationLayout* entry_computation_layout,
                     ChannelLayoutConstraints* channel_constraints = nullptr) {
    LayoutAssignment layout_assignment(
        entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout,
        /*channel_constraints=*/channel_constraints);
    EXPECT_IS_OK(layout_assignment.Run(m).status());
  }

  std::vector<int64> LayoutOf(HloModule* m, absl::string_view name) {
    auto minor_to_major =
        FindInstruction(m, name)->shape().layout().minor_to_major();
    return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
  }
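
  // Note on conventions: minor_to_major lists dimensions from the
  // fastest-varying (minor) dimension to the slowest-varying (major) one,
  // so for shape f32[R,C] with layout {1, 0}, element (i, j) sits at linear
  // index i * C + j. A hypothetical call for orientation (instruction name
  // "add" is made up):
  //
  //   LayoutOf(m, "add");  // e.g. returns {1, 0}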

  void ExpectLayoutIs(const Shape& shape,
                      absl::Span<const int64> minor_to_major) {
    const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
    EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
        << "Expected layout " << expected << ", actual " << shape.layout();
  }

  void ExpectTupleLayoutIs(
      const Shape& shape,
      std::initializer_list<absl::Span<const int64>> minor_to_majors) {
    int i = 0;
    for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
      const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
      const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
      EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
          << "Expected tuple element " << i << " layout " << expected
          << ", actual " << actual;
      ++i;
    }
  }
};
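
// A hypothetical use of the fixture helpers above, for orientation (the
// variables `shape` and `tuple_shape` are made up):
//
//   ExpectLayoutIs(shape, {1, 0});
//   ExpectTupleLayoutIs(tuple_shape, {{1, 0}, {0, 1}});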

TEST_F(LayoutAssignmentTest, ComputationLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout for two different layouts.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
    auto param0 = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ashape, "param0"));
    auto param1 = builder.AddInstruction(
        HloInstruction::CreateParameter(1, ashape, "param1"));
    auto add = builder.AddInstruction(
        HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_parameter_layout(0) = shape_layout;
    *computation_layout.mutable_parameter_layout(1) = shape_layout;
    *computation_layout.mutable_result_layout() = shape_layout;
    AssignLayouts(m.get(), &computation_layout);
    EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
  }
}

TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
  // Verify the layouts of the root and parameter instructions of a computation
  // match the ComputationLayout which has mixed layout.
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param1"));
  builder.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build());

  Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
  Shape col_major_shape(ashape);
  *col_major_shape.mutable_layout() = col_major_layout;
  const ShapeLayout col_major(col_major_shape);

  Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
  Shape row_major_shape(ashape);
  *row_major_shape.mutable_layout() = row_major_layout;
  const ShapeLayout row_major(row_major_shape);

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) = col_major;
  *computation_layout.mutable_parameter_layout(1) = row_major;
  *computation_layout.mutable_result_layout() = col_major;

  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(
      col_major_layout, computation->root_instruction()->shape().layout()));
}

TEST_F(LayoutAssignmentTest, FusionInstruction) {
  // Verify that the layouts of the fused parameters in a fusion instruction
  // match those of the fusion operands. Other fused instructions should have
  // no layout.
  std::vector<std::vector<int64>> minor_to_majors = {{0, 1}, {1, 0}};
  for (auto& minor_to_major : minor_to_majors) {
    auto builder = HloComputation::Builder(TestName());
    auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
    auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
    Shape ashape = constant_literal1.shape();

    auto constant1 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal1)));
    auto constant2 = builder.AddInstruction(
        HloInstruction::CreateConstant(std::move(constant_literal2)));
    auto add = builder.AddInstruction(HloInstruction::CreateBinary(
        ashape, HloOpcode::kAdd, constant1, constant2));
    auto negate1 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
    auto negate2 = builder.AddInstruction(
        HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));

    auto m = CreateNewVerifiedModule();
    HloComputation* computation = m->AddEntryComputation(builder.Build());

    auto fusion = computation->CreateFusionInstruction(
        {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);

    Layout layout = LayoutUtil::MakeLayout(minor_to_major);
    Shape shape(ashape);
    *shape.mutable_layout() = layout;
    const ShapeLayout shape_layout(shape);

    ComputationLayout computation_layout(computation->ComputeProgramShape());
    *computation_layout.mutable_result_layout() = shape_layout;

    AssignLayouts(m.get(), &computation_layout);

    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(0)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_parameter(1)->shape().layout()));
    EXPECT_TRUE(LayoutUtil::Equal(
        layout, fusion->fused_expression_root()->shape().layout()));

    // Inner fused node should not have layout.
    EXPECT_FALSE(LayoutUtil::HasLayout(
        fusion->fused_expression_root()->operand(0)->shape()));
  }
}

TEST_F(LayoutAssignmentTest, TupleLayout) {
  // Verify the layouts of a tuple are assigned properly (the element layouts
  // match their source).
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));
  // To avoid having to construct a tuple layout in the ComputationLayout
  // below, make the result of the computation an array.
  auto get_element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
  auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kNegate, get_element0));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));

  EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      negate->shape(), computation_layout.result_layout().shape()));
  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
      ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
}

TEST_F(LayoutAssignmentTest, TupleSelect) {
  // Verify that the layout of a select with tuple operands is assigned
  // properly.
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
  auto tuple0 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));
  auto tuple1 = builder.AddInstruction(
      HloInstruction::CreateTuple({constant0, constant1}));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));

  auto select = builder.AddInstruction(HloInstruction::CreateTernary(
      tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape =
      ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
}

TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
  // Construct the following computation, which has conflicting layouts for two
  // elements of a tuple that share the same source logical buffer:
  //
  // %constant = Constant(...)
  // %inner_tuple = Tuple(%constant)
  // %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
  //
  // The result layout is col-major for the first element and row-major for the
  // second. This creates a conflict where the element of the inner_tuple needs
  // to be both col-major and row-major. The conflict is resolved by
  // deep-copying the tuple and assigning the layouts of the copied arrays as
  // needed.
  auto builder = HloComputation::Builder(TestName());
  auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  auto inner_tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant}));
  auto nested_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({inner_tuple, inner_tuple}));

  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape result_shape = nested_tuple->shape();
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
      result_shape));

  LayoutAssignment layout_assignment(&computation_layout);
  AssignLayouts(m.get(), &computation_layout);

  // Layout assignment should have deep copied the result of the computation to
  // address the layout conflict. This results in several Tuple() and
  // GetTupleElement() instructions. Running algebraic simplification should
  // clean up the code to something like:
  //
  //  %constant = Constant(...) layout={1,0}
  //  %tuple.0 = Tuple(%constant) layout=({1,0})
  //  %copy = Copy(%constant) layout={0,1}  # layout transposed
  //  %tuple.1 = Tuple(%copy) layout=({0,1})
  //  %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
  //
  AlgebraicSimplifierOptions options(
      [](const Shape&, const Shape&) { return false; });
  options.set_is_layout_sensitive(true);
  EXPECT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie());
  HloInstruction* root = m->entry_computation()->root_instruction();
  // Verify layout of the root and the root's operands.
  EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
                               root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
                               root->operand(1)->shape()));

  // Verify the structure of the HLO graph.
  EXPECT_THAT(root,
              GmockMatch(m::Tuple(m::Tuple(m::Op().Is(constant)),
                                  m::Tuple(m::Copy(m::Op().Is(constant))))));
}

TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
  // param -> log -> reshape -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {3, 1, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto reshape =
      builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));

  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 2, 1, 3});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  auto log_minor_to_major =
      AsInt64Slice(log->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(log_minor_to_major, 1),
            PositionInContainer(log_minor_to_major, 2));

  auto reshape_minor_to_major =
      AsInt64Slice(reshape->shape().layout().minor_to_major());
  EXPECT_GT(PositionInContainer(reshape_minor_to_major, 0),
            PositionInContainer(reshape_minor_to_major, 2));
}

// Test whether LayoutAssignment assigns layouts to elementwise operations to
// keep linear indices valid across them, and to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
  // param -> log -> transpose -> tanh
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(bshape, log, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
  auto m = CreateNewVerifiedModule();
  auto computation = m->AddEntryComputation(builder.Build(tanh));

  Shape ashape_with_layout(ashape);
  Shape bshape_with_layout(bshape);
  *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ashape_with_layout);
  *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_TRUE(
      LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
  EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
                                transpose->shape().layout()));
  EXPECT_TRUE(
      LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
}
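
// Why the transpose above ends up as a bitcast: with ashape f32[42,12] laid
// out {1,0}, element (i, j) sits at linear index i * 12 + j; with bshape
// f32[12,42] laid out {0,1}, element (j, i) sits at j + 12 * i, the same
// offset. The permutation is absorbed entirely by the layouts, so no data
// movement is needed.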

// Test whether LayoutAssignment assigns layouts to transpositions to make them
// bitcasts.
TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
  // param -> broadcast -> transpose
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
  Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
  Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));

  Shape input_shape_with_layout(ashape);
  Shape output_shape_with_layout(cshape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *output_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(input_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(output_shape_with_layout);
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(),
              ElementsAre(0, 1, 2));
}

TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
  // param[4] -> broadcast[3x4] ------> transpose[4x3] -----------------> tuple
  //                            \                                       /
  //                             \-> tanh[3x4] -> broadcast2[2x3x4] ---/
  //
  // The layout of `transpose` is set to {1,0} because it provides a buffer to
  // the computation result, which has a fixed layout. Therefore, `broadcast`
  // (the operand of transpose) is expected to have layout {0,1} so that the
  // transpose is a bitcast. Furthermore, `tanh` is expected to have the same
  // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
  Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
  Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
  Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
  Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});

  auto builder = HloComputation::Builder(TestName());
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_4, "param"));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_34, param, {1}));
  auto transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
  auto tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
  auto broadcast2 = builder.AddInstruction(
      HloInstruction::CreateBroadcast(f32_234, tanh, {1, 2}));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({transpose, broadcast2}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(tuple));

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  Shape param_shape_with_layout(f32_4);
  Shape transpose_shape_with_layout(f32_43);
  Shape broadcast2_shape_with_layout(f32_234);
  *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
  *transpose_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({1, 0});
  *broadcast2_shape_with_layout.mutable_layout() =
      LayoutUtil::MakeLayout({2, 1, 0});

  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param_shape_with_layout);
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {transpose_shape_with_layout, broadcast2_shape_with_layout}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(broadcast->shape().layout().minor_to_major(), ElementsAre(0, 1));
  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(tanh->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

class OperandsMustBeTheSameLayoutAssignment : public LayoutAssignment {
 public:
  explicit OperandsMustBeTheSameLayoutAssignment(
      ComputationLayout* entry_computation_layout)
      : LayoutAssignment(entry_computation_layout) {}

 protected:
  Status PropagateBufferConstraint(
      const BufferLayoutConstraint& buffer_constraint,
      LayoutConstraints* constraints) override {
    const LogicalBuffer& buffer = buffer_constraint.buffer();
    const HloInstruction* instruction = buffer.instruction();

    // Force the operands' layout to the output layout.
    for (int64 operand_no = 0; operand_no < instruction->operand_count();
         ++operand_no) {
      const HloInstruction* operand = instruction->operand(operand_no);
      if (instruction->shape().rank() != operand->shape().rank()) {
        continue;
      }
      TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
          buffer_constraint.layout(), instruction, operand_no,
          /*mandatory=*/true));
    }
    return PropagateBufferConstraintToUses(buffer_constraint, constraints);
  }
};
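
// A note on the pattern above: backends tailor layout assignment by
// subclassing LayoutAssignment and overriding propagation hooks such as
// PropagateBufferConstraint. The subclass here deliberately over-constrains
// (every same-rank operand must match the output layout) so that the
// MakeOperandsTheSame test below can observe the copies layout assignment
// inserts to satisfy mandatory constraints.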

TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
  // param0 -> concatenate -> reshape
  // param1   -^
  auto builder = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {50, 1});
  Shape bshape = ShapeUtil::MakeShape(F32, {50, 2});
  Shape cshape = ShapeUtil::MakeShape(F32, {100});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, ashape, "param"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ashape, "param"));
  auto concatenate = builder.AddInstruction(
      HloInstruction::CreateConcatenate(bshape, {param0, param1}, 1));
  auto reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(cshape, concatenate));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation = m->AddEntryComputation(builder.Build(reshape));

  Shape param0_shape_with_layout(ashape);
  Shape param1_shape_with_layout(ashape);
  *param0_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  *param1_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  ComputationLayout computation_layout(computation->ComputeProgramShape());
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(param0_shape_with_layout);
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(param1_shape_with_layout);
  OperandsMustBeTheSameLayoutAssignment layout_assignment(&computation_layout);
  EXPECT_IS_OK(layout_assignment.Run(m.get()).status());

  EXPECT_EQ(HloOpcode::kCopy, concatenate->operand(0)->opcode());
  EXPECT_THAT(concatenate->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(concatenate->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(concatenate->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
}

// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape_with_layout =
      ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  auto broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(input_shape, constant, {}));
  auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
      ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
  auto m = CreateNewVerifiedModule();
  HloComputation* computation =
      m->AddEntryComputation(builder.Build(transpose));
  ComputationLayout computation_layout(computation->ComputeProgramShape());
  AssignLayouts(m.get(), &computation_layout);
  EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
                                            transpose->shape(), {2, 3, 0, 1}));
}

// TransposeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo = builder.AddInstruction(
      HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1}));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(),
                                             hlo->shape(), hlo->dimensions()),
               "LayoutUtil::HasLayout");
}

// ReshapeIsBitcast shouldn't be called without layout information.
TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) {
  auto builder = HloComputation::Builder(TestName());
  Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
  Shape input_shape_with_layout(input_shape);
  *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0});
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
  auto hlo =
      builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param));
  // Clear the default layout assigned to the instruction.
  LayoutUtil::ClearLayout(hlo->mutable_shape());
  EXPECT_DEATH(
      ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()),
      "LayoutUtil::HasLayout");
}
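
// Note: the two tests above use EXPECT_DEATH because the precondition is
// enforced with a CHECK (whose message matches "LayoutUtil::HasLayout")
// rather than a returned Status.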

// Check that the computation below doesn't crash the compiler.
//
// Within a fusion computation, only the parameters and result get assigned a
// layout.  When we run the algebraic simplifier on this computation post layout
// assignment, it should not call TransposeIsBitcast on the `transpose` node
// inside the fusion computation as TransposeIsBitcast checks both input_shape
// and output_shape have layouts.
TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      param_1 = f32[2,2,2]{2,1,0} parameter(1)
      transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1}
      reduce_1 = f32[] parameter(0)
      broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={}
      ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1)
    }

    ENTRY entry_computation {
      fusion.1 = f32[2,2,2]{2,1,0} parameter(1)
      reduce.1 = f32[] parameter(0)
      fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation
      ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  std::unique_ptr<HloModule> compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();

  EXPECT_EQ(Status::OK(), backend()
                              .compiler()
                              ->RunBackend(std::move(compiled_module),
                                           backend().default_stream_executor(),
                                           /*device_allocator=*/nullptr)
                              .status());
}

// A GTE inside of a fusion node inherits the layout of its operand (which
// should, if we keep following operands, eventually be a parameter).
TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
  const char* module_str = R"(
    HloModule test_module

    fused_computation {
      fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      gte0 = f32[2,2,2] get-tuple-element(fparam), index=0
      gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1
      gte1a = f32[2,2,2] get-tuple-element(gte1), index=0
      gte1b = f32[2,2,2] get-tuple-element(gte1), index=1
      add = f32[2,2,2] add(gte1a, gte1b)
      ROOT fresult = f32[2,2,2] add(gte0, add)
    }

    ENTRY entry_computation {
      param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0)
      ROOT fusion =
        f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}),
       ShapeUtil::MakeTupleShape({
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}),
       })});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({2, 1, 0}));
  AssignLayouts(m.get(), &computation_layout);

  EXPECT_THAT(LayoutOf(m.get(), "gte0"), ElementsAre(0, 1, 2));
  EXPECT_THAT(LayoutOf(m.get(), "gte1a"), ElementsAre(1, 2, 0));
  EXPECT_THAT(LayoutOf(m.get(), "gte1b"), ElementsAre(2, 0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "fresult"), ElementsAre(2, 1, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(0)
                  .layout()
                  .minor_to_major(),
              ElementsAre(1, 2, 0));
  EXPECT_THAT(FindInstruction(m.get(), "gte1")
                  ->shape()
                  .tuple_shapes(1)
                  .layout()
                  .minor_to_major(),
              ElementsAre(2, 0, 1));
}

TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
  auto builder = HloComputation::Builder(TestName());
  auto m = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(F32, {128, 8});
  Shape tshape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape result_tshape = ShapeUtil::MakeTupleShape({shape});

  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "param1"));
  auto pred = builder.AddInstruction(HloInstruction::CreateParameter(
      2, ShapeUtil::MakeShape(PRED, {}), "param2"));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));

  auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch");
  {
    auto param = true_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    auto gte0 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 0));
    auto gte1 = true_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(shape, param, 1));
    auto add = true_builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1));
    true_builder.AddInstruction(HloInstruction::CreateTuple({add}));
  }
  HloComputation* true_computation =
      m->AddEmbeddedComputation(true_builder.Build());

  auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch");
  {
    Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1});
    false_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tshape, "param"));
    // Use infeed, since layout assignment does not mess with it.
    auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
    auto infeed = false_builder.AddInstruction(
        HloInstruction::CreateInfeed(xshape, token, ""));
    auto infeed_data = false_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
    false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
  }
  HloComputation* false_computation =
      m->AddEmbeddedComputation(false_builder.Build());
  builder.AddInstruction(HloInstruction::CreateConditional(
      result_tshape, pred, tuple, true_computation, tuple, false_computation));

  HloComputation* computation = m->AddEntryComputation(builder.Build());
  ComputationLayout computation_layout(computation->ComputeProgramShape());

  AssignLayouts(m.get(), &computation_layout);

  const HloInstruction* true_root = true_computation->root_instruction();
  const HloInstruction* false_root = false_computation->root_instruction();
  EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple);
  EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple);

  const HloInstruction* true_result = true_root->operand(0);
  const HloInstruction* false_result = false_root->operand(0);
  EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(),
                                false_result->shape().layout()));
  EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}

TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
  auto builder = HloComputation::Builder(TestName());
  auto constant0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
  builder.AddInstruction(HloInstruction::CreateUnary(
      constant0->shape(), HloOpcode::kBitcast, constant0));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(builder.Build());

  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  LayoutAssignment layout_assignment(&computation_layout);
  Status error_status = layout_assignment.Run(m.get()).status();
  EXPECT_FALSE(error_status.ok());
  EXPECT_THAT(
      error_status.error_message(),
      ::testing::HasSubstr(
          "Unexpected bitcast operation seen during layout assignment"));
}

TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      token0 = token[] after-all()
      recv = (f32[2,2], u32[], token[]) recv(token0), channel_id=1, sharding={maximal device=1}
      recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
        sharding={maximal device=1}
      ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
      send = (f32[2,2], u32[], token[]) send(gte, token0), channel_id=1,
        sharding={maximal device=0}
      send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "root"), ElementsAre(1, 0));
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::GetSubshape(FindInstruction(m.get(), "send")->shape(), {0}),
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
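
// Note on the test above: values flowing through a channel (here
// channel_id=1) must agree on layout, so even though `gte` keeps the
// parameter's {0,1} layout, the buffer passed to `send` ends up in the {1,0}
// layout the channel adopted (see the EXPECT on the send subshape).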

TEST_F(LayoutAssignmentTest, AllReduceLayoutMismatch) {
  // Pin non-matching layouts to the parameter and root.
  const char* module_str = R"(
    HloModule test_module

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY entry_computation {
      param = (f32[2,2]) parameter(0)
      gte = f32[2,2] get-tuple-element(param), index=0
      ar.0 = f32[2,2] all-reduce(gte),
        all_reduce_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=0}
      const = f32[2,2] constant({{0,1},{2,3}})
      ROOT ar.1 = f32[2,2] all-reduce(const),
        all_reduce_id=1, replica_groups={{0}}, to_apply=add,
        sharding={maximal device=1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());
  Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
  TF_ASSERT_OK(
      computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
          param_shape));
  computation_layout.mutable_result_layout()->ResetLayout(
      LayoutUtil::MakeLayout({1, 0}));

  ChannelLayoutConstraints channel_constraints;
  AssignLayouts(m.get(), &computation_layout, &channel_constraints);

  EXPECT_THAT(LayoutOf(m.get(), "gte"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.0"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "ar.1"), ElementsAre(0, 1));
  const HloInstruction* root = m->entry_computation()->root_instruction();
  EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
}

TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopySliceOperandToAvoidImplicitLayoutChange

    ENTRY CopySliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]}
      ROOT add0 = f32[3,4]{1,0} add(par0,slice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Slice(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy)))));
}

TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyDSliceOperandToAvoidImplicitLayoutChange

    ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[4,5]{0,1} parameter(1)
      par2 = s32[] parameter(2)
      par3 = s32[] parameter(3)
      dslice0 = f32[3,4] dynamic-slice(par1, par2, par3), dynamic_slice_sizes={3,4}
      ROOT add0 = f32[3,4]{1,0} add(par0,dslice0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0});
  EXPECT_THAT(root,
              GmockMatch(m::Add(
                  m::Parameter(),
                  m::DynamicSlice(
                      m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                      m::Parameter(2), m::Parameter(3)))));
}

TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) {
  const char* module_str = R"(
    HloModule CopyConcatOperandToAvoidImplicitLayoutChange

    ENTRY CopyConcatOperandToAvoidImplicitLayoutChange {
      par0 = f32[3,8]{1,0} parameter(0)
      par1 = f32[3,5]{0,1} parameter(1)
      par2 = f32[3,3]{1,0} parameter(2)
      concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2),
        dimensions={1}
      ROOT add0 = f32[3,8]{1,0} add(par0,concat0)
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0});
  EXPECT_THAT(
      root,
      GmockMatch(m::Add(
          m::Parameter(),
          m::Concatenate(m::Copy(m::Parameter(1)).WithShapeEqualTo(&shape_copy),
                         m::Parameter(2)))));
}

TEST_F(LayoutAssignmentTest,
       ConvolutionOperandWithImplicitLayoutChangeNotCopied) {
  const char* module_str = R"(
    HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied

    ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied {
      par0 = f32[128,3,230,230]{2,3,1,0} parameter(0)
      par1 = f32[7,7,3,64]{3,2,0,1} parameter(1)
      ROOT convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1),
        window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01,
        feature_group_count=1
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
}

TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) {
  const char* module_str = R"(
    HloModule PropagatingLayoutFromResultToOperand

    ENTRY PropagatingLayoutFromResultToOperand {
      par0 = f32[4,5]{1,0} parameter(0)
      ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]}
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  auto compiled_module =
      backend()
          .compiler()
          ->RunHloPasses(m->Clone(), backend().default_stream_executor(),
                         /*device_allocator=*/nullptr)
          .ConsumeValueOrDie();
  HloInstruction* root =
      compiled_module->entry_computation()->root_instruction();
  Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1});
  EXPECT_THAT(root,
              GmockMatch(m::Slice(
                  m::Copy(m::Parameter(0)).WithShapeEqualTo(&shape_copy))));
}

TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
  // The first infeed uses layout {0,1}, while the second uses layout {1,0}.
  // The mismatch forces a copy of the tuple.  The tuple contains a token, so
  // layout assignment will fail if it tries to copy the whole tuple.
  const char* module_str = R"(
    HloModule TupleCopyOnLayoutMismatch

    condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] {
      tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.1 = s32[] get-tuple-element(tup.1), index=0
      five = s32[] constant(5)
      ROOT lt = pred[] compare(counter.1, five), direction=LT
    }

    body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) {
      tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0)
      counter.2 = s32[] get-tuple-element(tup.2), index=0
      tok.2 = token[] get-tuple-element(tup.2), index=1

      ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2)
      next_tok = token[] get-tuple-element(ifeed.2), index=1
      next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0

      one = s32[] constant(1)
      next_counter = s32[] add(counter.2, one)
      ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf)
    }

    ENTRY main () -> f32[512,1024]{0,1} {
      start_tok = token[] after-all()

      ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok)
      itok = token[] get-tuple-element(ifeed.3), index=1
      ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0

      zero = s32[] constant(0)
      itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf)

      loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2
      ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2
    }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(module_str));
  ComputationLayout computation_layout(
      m->entry_computation()->ComputeProgramShape());

  // Sanity check to verify that there's a layout mismatch.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));

  AssignLayouts(m.get(), &computation_layout);

  // Make sure that layout assignment did not magically eliminate the mismatch,
  // in which case the test didn't prove anything.
  EXPECT_THAT(LayoutOf(m.get(), "ibuf"), ElementsAre(0, 1));
  EXPECT_THAT(LayoutOf(m.get(), "next_buf"), ElementsAre(1, 0));
}

TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
  const char* module_str = R"(
HloModule CustomCallNotLayoutConstrained

ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
  %p = f32[42,2,3] parameter(0)
  ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
}
)";
  // Try a couple of different layouts. In each case the custom call's operand
  // and result layouts should match those of the computation.
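  // Because the custom call carries no operand_layout_constraints, layout
  // assignment is free to flow the entry layouts straight through it, so no
  // copies should be needed around the call.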
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
  }
  {
    TF_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<VerifiedHloModule> m,
        ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
    ComputationLayout computation_layout = m->entry_computation_layout();
    *computation_layout.mutable_parameter_layout(0) =
        ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
    *computation_layout.mutable_result_layout() = ShapeLayout(
        ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
    AssignLayouts(m.get(), &computation_layout);

    HloInstruction* root = m->entry_computation()->root_instruction();
    ASSERT_THAT(root, GmockMatch(m::CustomCall(m::Parameter())));
    ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
    ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
  }
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrained

ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  // Because of the layout mismatches, the custom call should be wrapped in
  // kCopy instructions on the mismatched operand and on its result.
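  // Concretely: operand 0 is constrained to {0,1} but parameter 0 arrives as
  // {1,0}, so it needs a copy, while operand 1's {1,0} constraint already
  // matches; the call's {3,2,0,1} result likewise needs a copy to reach the
  // requested {2,1,0,3} entry result layout.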
  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Copy(), m::Parameter()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedZeroOperands

ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall())));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleOperand

ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
  %p0 = f32[4,4] parameter(0)
  %p1 = f32[2,3] parameter(1)
  %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_parameter_layout(1) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
  *computation_layout.mutable_result_layout() = ShapeLayout(
      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
  AssignLayouts(m.get(), &computation_layout);

  HloInstruction* root = m->entry_computation()->root_instruction();
  ExpectLayoutIs(root->shape(), {2, 1, 0, 3});

  ASSERT_THAT(m->entry_computation()->root_instruction(),
              GmockMatch(m::Copy(m::CustomCall(m::Tuple()))));

  const HloInstruction* custom_call =
      m->entry_computation()->root_instruction()->operand(0);
  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
  ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
}

TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
  const char* module_str = R"(
HloModule CustomCallLayoutConstrainedTupleResult

ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
  %p0 = f32[4,4] parameter(0)
  ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
}
)";
  // The custom call's tuple result layout is fixed to {{1,0}, {0,1}} by the
  // HLO text, while the computation requests {{1,0}, {1,0}} as its result
  // layout, so the two must be reconciled.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> m,
      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
  ComputationLayout computation_layout = m->entry_computation_layout();
  *computation_layout.mutable_parameter_layout(0) =
      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
  *computation_layout.mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
           ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
  AssignLayouts(m.get(), &computation_layout);

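  // The entry result must honor the requested {{1,0}, {1,0}} layout even
  // though the custom call itself keeps its constrained {{1,0}, {0,1}}
  // result, so a layout-converting copy is expected between the two.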
  ExpectTupleLayoutIs(m->result_shape(), {{1, 0}, {1, 0}});

  const HloInstruction* custom_call = FindInstruction(m.get(), "custom-call");
  ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
}

Status AssignLayoutsToComputation(
    HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
  if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
    m->mutable_entry_computation_layout()
        ->mutable_result_layout()
        ->SetToDefaultLayout();
  }
  LayoutAssignment layout_assignment(
      m->mutable_entry_computation_layout(),
      LayoutAssignment::InstructionCanChangeLayout, channel_constraints);
  return layout_assignment.Run(m).status();
}
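// Typical usage, mirroring the test below: force any fixed layouts on the
// module first, then run the helper and check the resulting instruction
// layouts, e.g.
//   ForceParameterLayout(m.get(), 0, LayoutUtil::MakeLayout({1, 0}));
//   TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));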

TEST_F(LayoutAssignmentTest, OverwriteDiamondShapedConstraintsX) {
  // Check that we handle a diamond-shaped graph correctly.
  //      transpose
  //       /    \
  //     add    |
  //       \    /
  //        tuple
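  //
  // The entry result layout requests {1,0} for the add element of the tuple
  // but {0,1} for the transpose element, so the two users of the transpose
  // want conflicting layouts; the checks below expect the transpose to keep
  // {0,1} while the add consumes a {1,0} operand.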

  auto b = HloComputation::Builder(TestName());
  Shape ashape = ShapeUtil::MakeShape(F32, {12, 8});
  Shape bshape = ShapeUtil::MakeShape(F32, {8, 12});
  auto param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, bshape, "input"));
  auto param1 =
      b.AddInstruction(HloInstruction::CreateParameter(1, ashape, "input"));
  auto transpose =
      b.AddInstruction(HloInstruction::CreateTranspose(ashape, param0, {1, 0}));
  auto add = b.AddInstruction(
      HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, transpose, param1));
  b.AddInstruction(HloInstruction::CreateTuple({add, transpose}));
  auto m = CreateNewVerifiedModule();
  m->AddEntryComputation(b.Build());
  Shape ashape_major = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {1, 0});
  Shape ashape_minor = ShapeUtil::MakeShapeWithLayout(F32, {12, 8}, {0, 1});
  *m->mutable_entry_computation_layout()->mutable_result_layout() =
      ShapeLayout(ShapeUtil::MakeTupleShape({ashape_major, ashape_minor}));
  const Layout r2_dim0major = LayoutUtil::MakeLayout({1, 0});
  ForceParameterLayout(m.get(), 0, r2_dim0major);
  ForceParameterLayout(m.get(), 1, r2_dim0major);
  TF_ASSERT_OK(AssignLayoutsToComputation(m.get()));

  EXPECT_THAT(add->shape().layout().minor_to_major(), ElementsAre(1, 0));
  EXPECT_THAT(add->operand(0)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));
  EXPECT_THAT(add->operand(1)->shape().layout().minor_to_major(),
              ElementsAre(1, 0));

  EXPECT_THAT(transpose->shape().layout().minor_to_major(), ElementsAre(0, 1));
}

}  // namespace
}  // namespace xla