1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/compiler/xla/layout_util.h"
20 #include "tensorflow/compiler/xla/service/computation_layout.h"
21 #include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
22 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_layout.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/status_matchers.h"
34 #include "tensorflow/stream_executor/lib/statusor.h"
35
36 namespace xla {
37 namespace gpu {
38 namespace {
39
40 namespace op = xla::testing::opcode_matchers;
41 using ::tensorflow::testing::IsOkAndHolds;
42 using ::testing::AllOf;
43
44 using LayoutAssignmentTest = HloTestBase;
45
// GpuLayoutAssignment must assign one common layout to an elementwise op and
// both of its operands, no matter which row-/column-major combination the
// entry computation's parameter and result layouts request.
TEST_F(LayoutAssignmentTest, Elementwise) {
  Shape shape = ShapeUtil::MakeShape(F32, {42, 12});
  Shape row_major(shape);
  *row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
  Shape col_major(shape);
  *col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1});

  // Enumerate all possible combinations of layouts.
  for (const Shape& lhs_layout : {row_major, col_major}) {
    for (const Shape& rhs_layout : {row_major, col_major}) {
      for (const Shape& result_layout : {row_major, col_major}) {
        // Build "add(x, y)" as the entry computation.
        auto builder = HloComputation::Builder(TestName());
        HloInstruction* x = builder.AddInstruction(
            HloInstruction::CreateParameter(0, shape, "x"));
        HloInstruction* y = builder.AddInstruction(
            HloInstruction::CreateParameter(1, shape, "y"));
        HloInstruction* add = builder.AddInstruction(
            HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
        auto module = CreateNewVerifiedModule();
        HloComputation* entry =
            module->AddEntryComputation(builder.Build(add));

        // Pin the requested layouts for both parameters and the result.
        ComputationLayout computation_layout(entry->ComputeProgramShape());
        *computation_layout.mutable_parameter_layout(0) =
            ShapeLayout(lhs_layout);
        *computation_layout.mutable_parameter_layout(1) =
            ShapeLayout(rhs_layout);
        *computation_layout.mutable_result_layout() =
            ShapeLayout(result_layout);

        GpuLayoutAssignment layout_assignment(
            &computation_layout, backend().default_stream_executor());
        EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true));

        // "add" and each of its operands must agree on a single layout.
        for (const HloInstruction* operand : add->operands()) {
          EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
                                        operand->shape().layout()));
        }
      }
    }
  }
}
94
// A dot whose operand and result layouts are already usable by the GPU GEMM
// implementation must be left exactly as written in the input module.
TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) {
  const char* const kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[5,2,3]{1,2,0} parameter(0)
    p1 = f32[5,3,4]{1,2,0} parameter(1)
    ROOT dot.1330.10585 = f32[5,2,4]{2,1,0} dot(p0, p1),
      lhs_batch_dims={0}, lhs_contracting_dims={2},
      rhs_batch_dims={0}, rhs_contracting_dims={1}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  // ignore_layouts=false keeps the layouts from the HLO text as constraints.
  ComputationLayout layout(
      hlo_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  GpuLayoutAssignment pass(&layout, backend().default_stream_executor());
  EXPECT_THAT(pass.Run(hlo_module.get()), IsOkAndHolds(true));

  // Both operand layouts and the result layout survive unchanged.
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              AllOf(op::Dot(op::ShapeWithLayout("f32[5,2,3]{1,2,0}"),
                            op::ShapeWithLayout("f32[5,3,4]{1,2,0}")),
                    op::ShapeWithLayout("f32[5,2,4]{2,1,0}")));
}
120
// When the input module leaves layouts unconstrained (or specifies a
// non-default one that is not required), the pass should settle on the
// default descending layout {2,1,0} for every dot operand and the result.
TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) {
  const char* const kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[5,3,2] parameter(0)
    p1 = f32[5,4,3]{0,1,2} parameter(1)
    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
      lhs_batch_dims={0}, lhs_contracting_dims={1},
      rhs_batch_dims={0}, rhs_contracting_dims={2}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(
      hlo_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  GpuLayoutAssignment pass(&layout, backend().default_stream_executor());
  EXPECT_THAT(pass.Run(hlo_module.get()), IsOkAndHolds(true));

  // Everything ends up with the default major-to-minor layout.
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              AllOf(op::Dot(op::ShapeWithLayout("f32[5,3,2]{2,1,0}"),
                            op::ShapeWithLayout("f32[5,4,3]{2,1,0}")),
                    op::ShapeWithLayout("f32[5,2,4]{2,1,0}")));
}
147
// If the default layout does not put the dot dimensions in
// (batch, rows, cols) order, the pass picks an operand layout that does —
// here the batch dimension (index 2) becomes the most-major one.
TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) {
  const char* const kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[2,3,5]{2,1,0} parameter(0)
    p1 = f32[3,4,5] parameter(1)
    ROOT dot.1330.10585 = f32[5,2,4] dot(p0, p1),
      lhs_batch_dims={2}, lhs_contracting_dims={1},
      rhs_batch_dims={2}, rhs_contracting_dims={0}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(
      hlo_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  GpuLayoutAssignment pass(&layout, backend().default_stream_executor());
  EXPECT_THAT(pass.Run(hlo_module.get()), IsOkAndHolds(true));

  // {1,0,2} makes dimension 2 (the batch dim) most-major for both operands.
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Dot(op::ShapeWithLayout("f32[2,3,5]{1,0,2}"),
                      op::ShapeWithLayout("f32[3,4,5]{1,0,2}")));
}
173
// The two operands list their batch dimensions in different orders
// ({0,1} vs. {1,0}); the pass must still produce operand layouts whose
// physical batch ordering is consistent across both sides.
TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) {
  const char* const kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[5,6,2,3] parameter(0)
    p1 = f32[6,5,3,4] parameter(1)
    ROOT dot.1330.10585 = f32[5,6,2,4] dot(p0, p1),
      lhs_batch_dims={0,1}, lhs_contracting_dims={3},
      rhs_batch_dims={1,0}, rhs_contracting_dims={2}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(
      hlo_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  GpuLayoutAssignment pass(&layout, backend().default_stream_executor());
  EXPECT_THAT(pass.Run(hlo_module.get()), IsOkAndHolds(true));

  // The rhs layout {3,2,0,1} mirrors the swapped batch dims of the lhs.
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Dot(op::ShapeWithLayout("f32[5,6,2,3]{3,2,1,0}"),
                      op::ShapeWithLayout("f32[6,5,3,4]{3,2,0,1}")));
}
199
// A dot feeding a transpose: the pass may give the dot output a transposed
// layout ({2,0,1}) so the following transpose becomes a bitcast, while the
// dot operands and the final result keep default layouts.
TEST_F(LayoutAssignmentTest, TransposedDotLayout) {
  const char* const kHloText = R"(
  HloModule DotLayout
  ENTRY dot {
    p0 = f32[5,2,3] parameter(0)
    p1 = f32[5,3,4] parameter(1)
    dot = f32[5,2,4] dot(p0, p1),
      lhs_batch_dims={0}, lhs_contracting_dims={2},
      rhs_batch_dims={0}, rhs_contracting_dims={1}
    ROOT out = f32[2,5,4] transpose(dot), dimensions={1,0,2}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(
      hlo_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  GpuLayoutAssignment pass(&layout, backend().default_stream_executor());
  EXPECT_THAT(pass.Run(hlo_module.get()), IsOkAndHolds(true));

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              AllOf(op::Transpose(
                        AllOf(op::Dot(op::ShapeWithLayout("f32[5,2,3]{2,1,0}"),
                                      op::ShapeWithLayout("f32[5,3,4]{2,1,0}")),
                              op::ShapeWithLayout("f32[5,2,4]{2,0,1}"))),
                    op::ShapeWithLayout("f32[2,5,4]{2,1,0}")));
}
229
// int8 GEMMs have stricter operand-layout requirements: the rhs must be
// column-major ({0,1}) while the lhs stays row-major.
TEST_F(LayoutAssignmentTest, DotLayoutS8) {
  const char* const kHloText = R"(
  HloModule DotLayout
  ENTRY int8_t {
    p0 = s8[32,64] parameter(0)
    p1 = s8[64,96] parameter(1)
    ROOT out = s32[32,96] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(
      hlo_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  GpuLayoutAssignment pass(&layout, backend().default_stream_executor());
  EXPECT_THAT(pass.Run(hlo_module.get()), IsOkAndHolds(true));

  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Dot(op::ShapeWithLayout("s8[32,64]{1,0}"),
                      op::ShapeWithLayout("s8[64,96]{0,1}")));
}
253
// Both inputs of a sort must receive the same layout, even when one of them
// (the constant keys) starts out column-major.
TEST_F(LayoutAssignmentTest, SortLayout) {
  const char* const kHloText = R"(
  HloModule SortLayout

  compare {
    p.0.lhs = f32[] parameter(0)
    p.0.rhs = f32[] parameter(1)
    p.1.lhs = f32[] parameter(2)
    p.1.rhs = f32[] parameter(3)
    ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
  }

  ENTRY sort {
    keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}})
    values = f32[2,3]{1,0} parameter(0)
    transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0}
    ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose),
      dimensions={1}, to_apply=compare
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(
      hlo_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  GpuLayoutAssignment pass(&layout, backend().default_stream_executor());
  EXPECT_THAT(pass.Run(hlo_module.get()), IsOkAndHolds(true));

  // Both sort operands come out row-major.
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Sort(op::ShapeWithLayout("f32[3,2]{1,0}"),
                       op::ShapeWithLayout("f32[3,2]{1,0}")));
}
288
// cuFFT requires its input and output in the default layout; the pass must
// rewrite the column-major input and materialize the result transpose with a
// copy rather than a layout change.
TEST_F(LayoutAssignmentTest, FftLayout) {
  const char* const kHloText = R"(
  HloModule Fft_module

  ENTRY Fft {
    input = c64[8,32]{0,1} parameter(0)
    fft = c64[8,32] fft(input), fft_type=FFT, fft_length={32}
    ROOT transpose = c64[32,8] transpose(fft), dimensions={1,0}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  ComputationLayout layout(
      hlo_module->entry_computation()->ComputeProgramShape(),
      /*ignore_layouts=*/false);
  GpuLayoutAssignment pass(&layout, backend().default_stream_executor());
  EXPECT_THAT(pass.Run(hlo_module.get()), IsOkAndHolds(true));

  // The fft and its operand are forced into {1,0}; the root is a copy of the
  // transpose.
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Copy(op::Transpose(
                  AllOf(op::Fft(op::ShapeWithLayout("c64[8,32]{1,0}")),
                        op::ShapeWithLayout("c64[8,32]{1,0}")))));
}
314
315 } // namespace
316 } // namespace gpu
317 } // namespace xla
318