• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/layout_util.h"
20 #include "tensorflow/compiler/xla/service/computation_layout.h"
21 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
22 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_layout.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/status_matchers.h"
34 #include "tensorflow/stream_executor/lib/statusor.h"
35 
36 namespace xla {
37 namespace gpu {
38 namespace {
39 
40 namespace op = xla::testing::opcode_matchers;
41 using ::tensorflow::testing::IsOkAndHolds;
42 using ::testing::AllOf;
43 
44 using LayoutAssignmentTest = HloTestBase;
45 
TEST_F(LayoutAssignmentTest, Elementwise) {
  // GpuLayoutAssignment must give an elementwise add and both of its
  // operands identical layouts, no matter which layouts the entry
  // computation requests for the parameters and the result.
  Shape shape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape row_major(shape);
  Shape col_major(shape);
  *row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  *col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  const Shape layout_choices[] = {row_major, col_major};
  // Enumerate every (lhs, rhs, result) layout combination.
  for (const Shape& lhs_layout : layout_choices) {
    for (const Shape& rhs_layout : layout_choices) {
      for (const Shape& result_layout : layout_choices) {
        // Build:  add = x + y  as the entry computation.
        auto builder = HloComputation::Builder(TestName());
        HloInstruction* x = builder.AddInstruction(
            HloInstruction::CreateParameter(0, shape, "x"));
        HloInstruction* y = builder.AddInstruction(
            HloInstruction::CreateParameter(1, shape, "y"));
        HloInstruction* add = builder.AddInstruction(
            HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
        auto module = CreateNewVerifiedModule();
        HloComputation* computation =
            module->AddEntryComputation(builder.Build(add));

        // Pin the parameter and result layouts for this combination.
        ComputationLayout computation_layout(
            computation->ComputeProgramShape());
        *computation_layout.mutable_parameter_layout(0) =
            ShapeLayout(lhs_layout);
        *computation_layout.mutable_parameter_layout(1) =
            ShapeLayout(rhs_layout);
        *computation_layout.mutable_result_layout() =
            ShapeLayout(result_layout);

        GpuLayoutAssignment layout_assignment(
            &computation_layout, backend().default_stream_executor());
        EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));

        // After the pass, "add" and each of its operands must agree.
        for (const HloInstruction* operand : add->operands()) {
          EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
                                        operand->shape().layout()));
        }
      }
    }
  }
}
94 
TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) {
  // When the dot's operand and result layouts written in the HLO are already
  // acceptable, the pass must leave them exactly as given.
  const char* kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[5,2,3]{1,2,0} parameter(0)
    p1 = f32[5,3,4]{1,2,0} parameter(1)
    ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
      lhs_batch_dims={0}, lhs_contracting_dims={2},
      rhs_batch_dims={0}, rhs_contracting_dims={1}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  // ignore_layouts=false: respect the layouts parsed from the HLO text.
  ComputationLayout layout(module->entry_computation()->ComputeProgramShape(),
                           /*ignore_layouts=*/false);
  GpuLayoutAssignment assignment(&layout,
                                 backend().default_stream_executor());
  EXPECT_THAT(assignment.Run(module.get()), IsOkAndHolds(true));

  // Operands and result keep the layouts spelled out in the input HLO.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Dot(op::ShapeWithLayout("f32[5,2,3]{1,2,0}"),
                            op::ShapeWithLayout("f32[5,3,4]{1,2,0}")),
                    op::ShapeWithLayout("f32[5,2,4]{2,1,0}")));
}
120 
TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) {
  // Where the HLO leaves layouts unspecified (or gives a non-default one on
  // an operand), the pass settles on the default descending layout {2,1,0}
  // for all dot shapes when that choice is valid.
  const char* kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[5,3,2] parameter(0)
    p1 = f32[5,4,3]{0,1,2} parameter(1)
    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
      lhs_batch_dims={0}, lhs_contracting_dims={1},
      rhs_batch_dims={0}, rhs_contracting_dims={2}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(module->entry_computation()->ComputeProgramShape(),
                           /*ignore_layouts=*/false);
  GpuLayoutAssignment assignment(&layout,
                                 backend().default_stream_executor());

  EXPECT_THAT(assignment.Run(module.get()), IsOkAndHolds(true));

  // Every shape involved ends up with the default {2,1,0} layout.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Dot(op::ShapeWithLayout("f32[5,3,2]{2,1,0}"),
                            op::ShapeWithLayout("f32[5,4,3]{2,1,0}")),
                    op::ShapeWithLayout("f32[5,2,4]{2,1,0}")));
}
147 
TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) {
  // Here the batch dimension is the last logical dimension (index 2), so the
  // default {2,1,0} layout is not usable; the pass instead orders each
  // operand's layout as batch-major, then rows, then columns: {1,0,2}.
  const char* kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[2,3,5]{2,1,0} parameter(0)
    p1 = f32[3,4,5] parameter(1)
    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
      lhs_batch_dims={2}, lhs_contracting_dims={1},
      rhs_batch_dims={2}, rhs_contracting_dims={0}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(module->entry_computation()->ComputeProgramShape(),
                           /*ignore_layouts=*/false);
  GpuLayoutAssignment assignment(&layout,
                                 backend().default_stream_executor());

  EXPECT_THAT(assignment.Run(module.get()), IsOkAndHolds(true));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Dot(op::ShapeWithLayout("f32[2,3,5]{1,0,2}"),
                      op::ShapeWithLayout("f32[3,4,5]{1,0,2}")));
}
173 
TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) {
  // The rhs lists its batch dims in the opposite order ({1,0}) from the lhs
  // ({0,1}); the pass reflects that permutation in the rhs layout ({3,2,0,1})
  // while the lhs keeps the default ({3,2,1,0}).
  const char* kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[5,6,2,3] parameter(0)
    p1 = f32[6,5,3,4] parameter(1)
    ROOT dot.1330.10585 = f32[5,6,2,4] dot(p0, p1),
      lhs_batch_dims={0,1}, lhs_contracting_dims={3},
      rhs_batch_dims={1,0}, rhs_contracting_dims={2}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(module->entry_computation()->ComputeProgramShape(),
                           /*ignore_layouts=*/false);
  GpuLayoutAssignment assignment(&layout,
                                 backend().default_stream_executor());

  EXPECT_THAT(assignment.Run(module.get()), IsOkAndHolds(true));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Dot(op::ShapeWithLayout("f32[5,6,2,3]{3,2,1,0}"),
                      op::ShapeWithLayout("f32[6,5,3,4]{3,2,0,1}")));
}
199 
TEST_F(LayoutAssignmentTest, TransposedDotLayout) {
  // A transpose consuming the dot lets the pass give the dot a transposed
  // result layout ({2,0,1}) so the transpose's own output can be the default
  // {2,1,0}; the dot operands stay at the default layout.
  const char* kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[5,2,3] parameter(0)
    p1 = f32[5,3,4] parameter(1)
    dot = f32[5,2,4] dot(p0, p1),
      lhs_batch_dims={0}, lhs_contracting_dims={2},
      rhs_batch_dims={0}, rhs_contracting_dims={1}
    ROOT out = f32[2,5,4] transpose(dot), dimensions={1,0,2}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(module->entry_computation()->ComputeProgramShape(),
                           /*ignore_layouts=*/false);
  GpuLayoutAssignment assignment(&layout,
                                 backend().default_stream_executor());

  EXPECT_THAT(assignment.Run(module.get()), IsOkAndHolds(true));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              AllOf(op::Transpose(
                        AllOf(op::Dot(op::ShapeWithLayout("f32[5,2,3]{2,1,0}"),
                                      op::ShapeWithLayout("f32[5,3,4]{2,1,0}")),
                              op::ShapeWithLayout("f32[5,2,4]{2,0,1}"))),
                    op::ShapeWithLayout("f32[2,5,4]{2,1,0}")));
}
229 
TEST_F(LayoutAssignmentTest, DotLayoutS8) {
  // For an s8 x s8 -> s32 dot the pass requests a row-major lhs ({1,0}) and
  // a column-major rhs ({0,1}).
  const char* kHloText = R"(
  HloModule DotLayout
  ENTRY int8_t {
    p0 = s8[32,64] parameter(0)
    p1 = s8[64,96] parameter(1)
    ROOT out = s32[32,96] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(module->entry_computation()->ComputeProgramShape(),
                           /*ignore_layouts=*/false);
  GpuLayoutAssignment assignment(&layout,
                                 backend().default_stream_executor());

  EXPECT_THAT(assignment.Run(module.get()), IsOkAndHolds(true));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Dot(op::ShapeWithLayout("s8[32,64]{1,0}"),
                      op::ShapeWithLayout("s8[64,96]{0,1}")));
}
253 
TEST_F(LayoutAssignmentTest, SortLayout) {
  // Both sort operands must come out with the same {1,0} layout, even though
  // the keys constant is written column-major ({0,1}) in the input HLO.
  const char* kHloText = R"(
  HloModule SortLayout

  compare {
    p.0.lhs = f32[] parameter(0)
    p.0.rhs = f32[] parameter(1)
    p.1.lhs = f32[] parameter(2)
    p.1.rhs = f32[] parameter(3)
    ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
  }

  ENTRY sort {
    keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}})
    values = f32[2,3]{1,0} parameter(0)
    transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0}
    ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose),
      dimensions={1}, to_apply=compare
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(module->entry_computation()->ComputeProgramShape(),
                           /*ignore_layouts=*/false);
  GpuLayoutAssignment assignment(&layout,
                                 backend().default_stream_executor());

  EXPECT_THAT(assignment.Run(module.get()), IsOkAndHolds(true));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Sort(op::ShapeWithLayout("f32[3,2]{1,0}"),
                       op::ShapeWithLayout("f32[3,2]{1,0}")));
}
288 
TEST_F(LayoutAssignmentTest, FftLayout) {
  // The fft instruction must get {1,0} layouts on both its operand and its
  // result; the column-major parameter is therefore copied, and the final
  // transpose is fed through a copy as well.
  const char* kHloText = R"(
  HloModule Fft_module

  ENTRY Fft {
    input = c64[8,32]{0,1} parameter(0)
    fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
    ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(module->entry_computation()->ComputeProgramShape(),
                           /*ignore_layouts=*/false);
  GpuLayoutAssignment assignment(&layout,
                                 backend().default_stream_executor());

  EXPECT_THAT(assignment.Run(module.get()), IsOkAndHolds(true));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Copy(op::Transpose(
                  AllOf(op::Fft(op::ShapeWithLayout("c64[8,32]{1,0}")),
                        op::ShapeWithLayout("c64[8,32]{1,0}")))));
}
314 
315 }  // namespace
316 }  // namespace gpu
317 }  // namespace xla
318