/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>

#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace gpu {
namespace {

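// Tests in this file check how the GPU backend tiles kernels for transposes
// and reductions by pattern-matching the emitted LLVM IR (barriers, warp
// shuffles, and global-memory stores), and then verify numerical correctness
// by running the compiled kernels.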
class GpuKernelTilingTest : public GpuCodegenTest {
 protected:
  GpuKernelTilingTest() {}

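  // Rewrites the platform-neutral placeholders used in the FileCheck patterns
  // below (KERNEL_ANNOTATION, BARRIER, SHUFFLE, TIDX) into their ROCm or CUDA
  // spellings, depending on which backend this test was built for.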
  std::string MakePlatformSpecific(absl::string_view input) {
    return absl::StrReplaceAll(
        input,
        {{"KERNEL_ANNOTATION",
          is_built_with_rocm_ ? "amdgpu_kernel void" : "void"},
         {"BARRIER", is_built_with_rocm_ ? "@llvm.amdgcn.s.barrier"
                                         : "@llvm.nvvm.barrier0"},
         {"SHUFFLE", is_built_with_rocm_
                         ? "i32 @llvm.amdgcn.ds.bpermute"
                         : "float @llvm.nvvm.shfl.sync.down.f32"},
         {"TIDX", is_built_with_rocm_ ? "llvm.amdgcn.workitem.id.x"
                                      : "@llvm.nvvm.read.ptx.sreg.tid.x"}});
  }

  // Most tests in this file want to skip layout assignment, but a few need it
  // enabled.
  HloModuleConfig ConfigWithLayoutAssignment() {
    return GetModuleConfigForTest();
  }

  HloModuleConfig ConfigWithoutLayoutAssignment() {
    HloModuleConfig config;
    auto debug_options = HloTestBase::GetDebugOptionsForTest();
    // Disable layout_assignment to use the preassigned layouts.
    debug_options.add_xla_disable_hlo_passes("layout-assignment");
    config.set_debug_options(debug_options);
    return config;
  }
};

TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) {
  const char *const kHloString = R"(
    HloModule unnested_transpose_1

    ENTRY unnested_transpose_1 {
      para0 = f16[32,3,64]{2,1,0} parameter(0)
      ROOT copy1 = f16[32,3,64]{1,0,2} copy(para0)
    })";

  // Check that a call to llvm.nvvm.barrier0 is generated.
  //
  // We must enable layout assignment in order for this test to work correctly.
  // AlgebraicSimplifier removes copy1; it's added back by layout assignment,
  // which respects the module's entry computation layout.  But if we don't run
  // layout assignment...well, nobody else adds the copy back.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithLayoutAssignment())
          .ValueOrDie();

  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @copy
; CHECK: call void BARRIER()
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0}));
}

TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) {
  const char *const kHloString = R"(
    HloModule unnested_transpose_2

    ENTRY unnested_transpose_2 {
      para0 = f16[2,3,64]{2,1,0} parameter(0)
      ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0)
    })";

  // Check that a call to llvm.nvvm.barrier0 is not generated.  As in
  // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment
  // here.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @copy
; CHECK-NOT: call void BARRIER()
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);
}

TEST_F(GpuKernelTilingTest, UnnestedTransposeC128TypeRun) {
  const char *const kHloString = R"(
    HloModule unnested_transpose_3

    ENTRY unnested_transpose_3 {
      para0 = c128[65,65]{1,0} parameter(0)
      ROOT copy1 = c128[65,65]{0,1} copy(para0)
    })";

  // With the current implementation for the available hardware, we bail out
  // of the tiled transpose implementation at the last minute. Instead of
  // checking that the transpose is not tiled, this test only checks that the
  // module compiles and runs.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0}));
}

TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) {
  const char *const kHloString = R"(
    HloModule multiple_output_fusion_1
    fused_computation.1 {
      param0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0)
      copy = f32[4,5,6,7,8]{2,1,4,3,0} copy(param0)
      ROOT convert = f16[4,5,6,7,8]{2,1,4,3,0} convert(copy)
    }

    ENTRY copy_in_fusion_run_without_hlo_passes {
      para0 = f32[4,5,6,7,8]{4,3,2,1,0} parameter(0)
      ROOT fusion.1 = f16[4,5,6,7,8]{2,1,4,3,0} fusion(para0), kind=kLoop,
        calls=fused_computation.1
    })";

  // Check that a call to llvm.nvvm.barrier0 is generated.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @fusion
; CHECK: call void BARRIER()
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0}));
}

TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) {
  const char *const kHloString = R"(
    HloModule multiple_output_fusion_1
    fused_computation.1 {
      param0 = f16[8,31,31,65]{3,2,1,0} parameter(0)
      param1 = f16[8,31,31,65]{3,2,1,0} parameter(1)
      copy0 = f16[8,31,31,65]{2,1,3,0} copy(param0)
      copy1 = f16[8,31,31,65]{2,1,3,0} copy(param1)
      ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0})
        tuple(copy0, copy1)
    }

    ENTRY multiple_output_fusion_1 {
      para0 = f16[8,31,31,65]{3,2,1,0} parameter(0)
      para1 = f16[8,31,31,65]{3,2,1,0} parameter(1)
      ROOT fusion.1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0})
        fusion(para0,para1), kind=kLoop, calls=fused_computation.1
    })";

  // Check that a call to llvm.nvvm.barrier0 is generated.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @fusion
; CHECK: call void BARRIER()
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0}));
}

TEST_F(GpuKernelTilingTest,
       MultipleOutputFusionWithTwoPossibleTransposesNotTiled) {
  const char *const kHloString = R"(
    HloModule multiple_output_fusion_2
    fused_computation.1 {
      param0 = f16[8,31,31,65]{3,2,1,0} parameter(0)
      param1 = f16[8,31,31,65]{1,3,2,0} parameter(1)
      copy2 = f16[8,31,31,65]{2,1,3,0} copy(param0)
      copy3 = f16[8,31,31,65]{2,1,3,0} copy(param1)
      ROOT tuple1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0})
        tuple(copy2, copy3)
    }

    ENTRY multiple_output_fusion_2 {
      para0 = f16[8,31,31,65]{3,2,1,0} parameter(0)
      para1 = f16[8,31,31,65]{1,3,2,0} parameter(1)
      ROOT fusion1 = (f16[8,31,31,65]{2,1,3,0}, f16[8,31,31,65]{2,1,3,0})
        fusion(para0,para1), kind=kLoop, calls=fused_computation.1
    })";

  // Check that a call to llvm.nvvm.barrier0 is not generated.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @fusion
; CHECK-NOT: call void BARRIER()
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);
}

TEST_F(GpuKernelTilingTest, TransposedInputWithUserReverseNotTiled) {
  const char *const kHloString = R"(
    HloModule FusionTransposeWithReverseNotTiled
    fused_computation.1 {
      arg0 = f32[128,64]{1,0} parameter(0)
      copy0 = f32[128,64]{0,1} copy(arg0)
      ROOT reverse0 = f32[128,64]{0,1} reverse(copy0), dimensions={0}
    }

    ENTRY reverse_break_assumption {
      param0 = f32[128,64]{1,0} parameter(0)
      ROOT fusion0 = f32[128,64]{0,1} fusion(param0), kind=kLoop,
        calls=fused_computation.1
    })";

  // Check that a call to llvm.nvvm.barrier0 is not generated.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @fusion
; CHECK-NOT: call void BARRIER()
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);
}

TEST_F(GpuKernelTilingTest, TransposedInputWithUserBitcastNotTiled) {
  const char *const kHloString = R"(
    HloModule TransposedInputWithUserBitcast

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      ROOT bitcast = f32[20,20]{0,1} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = f32[20,20]{0,1} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  // Check that a call to llvm.nvvm.barrier0 is not generated.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @fusion
; CHECK-NOT: call void BARRIER()
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0}));
}

TEST_F(GpuKernelTilingTest, TransposedInputWithoutUnsafeUseTiled) {
  const char *const kHloString = R"(
    HloModule TwoTransposedInputs

    fused_computation {
      param_0 = f32[64,64]{1,0} parameter(0)
      param_1 = f32[64,64]{1,0} parameter(1)
      bitcast = f32[64,64]{0,1} bitcast(param_0)
      copy = f32[64,64]{0,1} copy(param_1)
      ROOT tuple = (f32[64,64]{0,1}, f32[64,64]{0,1}) tuple(bitcast, copy)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[64,64]{1,0} parameter(0)
      parameter.1 = f32[64,64]{1,0} parameter(1)
      ROOT fusion = (f32[64,64]{0,1}, f32[64,64]{0,1})
        fusion(parameter.0, parameter.1),
        kind=kLoop, calls=fused_computation
    })";

  // Check that a call to llvm.nvvm.barrier0 is generated.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @fusion
; CHECK: call void BARRIER()
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0}));
}

TEST_F(GpuKernelTilingTest, ColumnReductionWithPowerOf2OutputElementsUnrolled) {
  const char *const kHloString = R"(
  HloModule column_reduce_powerof2

  reduction {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  ENTRY kernel_entry {
    constant0 = f32[] constant(0)
    arg1 = f16[1024,512]{1,0} parameter(0)
    arg1_conv = f32[1024,512]{1,0} convert(arg1)
    ROOT reduce = f32[512]{0} reduce(arg1_conv, constant0), dimensions={0}, to_apply=reduction
  })";

  // Check that two float stores to global memory are generated.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  const char *expected_ir = R"(
; CHECK: store float %{{.*}}, ptr addrspace(1)
; CHECK: store float %{{.*}}, ptr addrspace(1)
)";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/true);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5}));
}

TEST_F(GpuKernelTilingTest,
       ColumnReductionWithInputLargerThenReduceInputNotUnrolled) {
  const char *const kHloString = R"(
  HloModule larger_than_reduce_input_parameter

  reduction22 {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  fused_computation {
    constant0 = f32[] constant(0)
    arg.1 = f16[1024,512]{1,0} parameter(0)
    arg.2 = f16[1027,513]{1,0} parameter(1)
    arg1.conv = f32[1024,512]{1,0} convert(arg.1)
    arg2.conv = f32[1027,513]{1,0} convert(arg.2)
    slice2 = f32[1024,512]{1,0} slice(arg2.conv), slice={[2:1026], [1:513]}
    add2 = f32[1024,512]{1,0} add(arg1.conv, slice2)
    ROOT reduce = f32[512]{0} reduce(add2, constant0), dimensions={0},
      to_apply=reduction22
  }

  ENTRY kernel_entry {
    arg1 = f16[1024,512]{1,0} parameter(0)
    arg2 = f16[1027,513]{1,0} parameter(1)
    ROOT fusion = f32[512]{0} fusion(arg1, arg2), kind=kInput,
      calls=fused_computation
  })";

  // Check that only one float store to global memory is generated.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  const char *expected_ir = R"(
; CHECK: store float %{{.*}}, ptr addrspace(1)
; CHECK-NOT: store float %{{.*}}, ptr addrspace(1)
)";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/true);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5}));
}

TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) {
  const char *const kHloString = R"(
  HloModule column_reduce_powerof2_mof

  reduction22 {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  fused_computation {
    constant0 = f32[] constant(0)
    arg.1 = f16[1024,512]{1,0} parameter(0)
    arg.2 = f16[1024,512]{1,0} parameter(1)
    arg1.conv = f32[1024,512]{1,0} convert(arg.1)
    arg2.conv = f32[1024,512]{1,0} convert(arg.2)
    reduce1 = f32[512]{0} reduce(arg1.conv, constant0), dimensions={0},
      to_apply=reduction22
    reduce2 = f32[512]{0} reduce(arg2.conv, constant0), dimensions={0},
      to_apply=reduction22
    add = f32[1024,512]{1,0} add(arg1.conv, arg2.conv)
    ROOT tuple = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0})
      tuple(reduce1, reduce2, add)
  }

  ENTRY kernel_entry {
    arg1 = f16[1024,512]{1,0} parameter(0)
    arg2 = f16[1024,512]{1,0} parameter(1)
    ROOT fusion = (f32[512]{0}, f32[512]{0}, f32[1024,512]{1,0})
      fusion(arg1, arg2), kind=kInput, calls=fused_computation
  })";

  // Check that exactly four float stores to global memory are generated.
  std::unique_ptr<VerifiedHloModule> hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  const char *expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @fusion
; CHECK: store float %{{.*}}, ptr addrspace(1)
; CHECK: store float %{{.*}}, ptr addrspace(1)
; CHECK: store float %{{.*}}, ptr addrspace(1)
; CHECK: store float %{{.*}}, ptr addrspace(1)
; CHECK-NOT: store float %{{.*}}, ptr addrspace(1)
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5}));
}

TEST_F(GpuKernelTilingTest, ColumnReductionWithLayoutChangeTiled) {
  const char *const kHloString = R"(
    HloModule reduce_with_layout_change
    reduction0 {
      x0 = f32[] parameter(0)
      y0 = f32[] parameter(1)
      ROOT add0 = f32[] add(x0, y0)
    }

    ENTRY kernel_entry {
      arg0 = f32[4,32,32,16,12,12,3,3]{2,3,5,4,0,7,6,1}  parameter(0)
      constant0 = f32[] constant(0)
      ROOT reduce0 = f32[4,32,16,12,12]{4,3,2,1,0} reduce(arg0, constant0),
        dimensions={1,6,7}, to_apply=reduction0
    })";

  // Check that the kernel is tiled by looking for a float store to global
  // memory.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  const char *expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @
; CHECK: store float %{{.*}}, ptr addrspace(1)
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}

TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) {
  const char *const kHloString = R"(
    HloModule reduce_with_layout_change
    reduction0 {
      x0 = f32[] parameter(0)
      y0 = f32[] parameter(1)
      ROOT add0 = f32[] add(x0, y0)
    }

    ENTRY kernel_entry {
      arg0 = f32[8,6,64]{2,1,0}  parameter(0)
      constant0 = f32[] constant(0)
      ROOT reduce0 = f32[8,6]{0,1} reduce(arg0, constant0), dimensions={2},
        to_apply=reduction0
    })";

  // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @reduce
; CHECK: call SHUFFLE
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}

TEST_F(GpuKernelTilingTest, RowReductionTwoRowsPerWarp) {
  const char *const kHloString = R"(
    HloModule reduce_with_layout_change
    reduction0 {
      x0 = f32[] parameter(0)
      y0 = f32[] parameter(1)
      ROOT add0 = f32[] add(x0, y0)
    }

    ENTRY kernel_entry {
      arg0 = f32[10000,16]{1,0}  parameter(0)
      constant0 = f32[] constant(0)
      ROOT reduce0 = f32[10000]{0} reduce(arg0, constant0), dimensions={1},
        to_apply=reduction0
    })";

  // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down and
  // a write condition based on the logical thread ID (two writes per warp).
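  // Each row has 16 elements, so one 32-thread warp covers two rows; the
  // logical thread ID below is the thread index masked with 15 (tid_x & 15),
  // and only logical thread 0 of each row writes that row's result.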
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @reduce
; CHECK: %[[TID_X:.*]] = tail call i32 TIDX()
; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 15
; CHECK: call SHUFFLE
; CHECK: %[[LOGICAL_T0:.*]] = icmp eq i32 %[[TID_LOGICAL]], 0
; CHECK: br i1 %[[LOGICAL_T0]],
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}

TEST_F(GpuKernelTilingTest, RowReductionFourRowsPerWarp) {
  const char *const kHloString = R"(
    HloModule reduce_with_layout_change
    reduction0 {
      x0 = f32[] parameter(0)
      y0 = f32[] parameter(1)
      ROOT add0 = f32[] add(x0, y0)
    }

    ENTRY kernel_entry {
      arg0 = f32[10000,8]{1,0}  parameter(0)
      constant0 = f32[] constant(0)
      ROOT reduce0 = f32[10000]{0} reduce(arg0, constant0), dimensions={1},
        to_apply=reduction0
    })";

  // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down and
  // a write condition based on the logical thread ID (four writes per warp).
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @reduce
; CHECK: %[[TID_X:.*]] = tail call i32 TIDX()
; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 7
; CHECK: call SHUFFLE
; CHECK: %[[LOGICAL_T0:.*]] = icmp eq i32 %[[TID_LOGICAL]], 0
; CHECK: br i1 %[[LOGICAL_T0]],
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}

TEST_F(GpuKernelTilingTest,
       ColumnReductionResultTwoPartsWithLayoutChangeTiled) {
  const char *const kHloString = R"(
    HloModule reduce_with_no_layout_change
    reduction0 {
      x0 = f32[] parameter(0)
      y0 = f32[] parameter(1)
      ROOT add0 = f32[] add(x0, y0)
    }

    ENTRY kernel_entry {
      arg0 = f32[8,64,32]{2,1,0}  parameter(0)
      constant0 = f32[] constant(0)
      ROOT reduce0 = f32[8,32]{0,1} reduce(arg0, constant0), dimensions={1},
        to_apply=reduction0
    })";

  // Check that the kernel is tiled by looking for a float store to global
  // memory.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  const char *expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @reduce
; CHECK: store float %{{.*}}, ptr addrspace(1)
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}

TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) {
  const char *const kHloString = R"(
  HloModule Test

  scalar_add_computation.1 {
    scalar_lhs.1 = f32[] parameter(0)
    scalar_rhs.1 = f32[] parameter(1)
    ROOT add.6 = f32[] add(scalar_lhs.1, scalar_rhs.1)
  }
  ENTRY Test {
    param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3)
    constant_661 = f16[] constant(0)
    broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(constant_661), dimensions={}
    compare.42 = pred[512,2,9,9]{1,3,2,0} compare(param_3.241, broadcast.695), direction=GT
    param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2)
    select.40 = f16[512,2,9,9]{1,3,2,0} select(compare.42, param_2.401, broadcast.695)
    convert.196 = f32[512,2,9,9]{1,3,2,0} convert(select.40)
    param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1)
    copy.335 = f16[512,2,9,9]{1,3,2,0} copy(param_1.809)
    convert.218 = f32[512,2,9,9]{1,3,2,0} convert(copy.335)
    param_0.668 = f32[2]{0} parameter(0)
    broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(param_0.668), dimensions={1}
    subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(convert.218, broadcast.687)
    multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(convert.196, subtract.136)
    constant_485 = f32[] constant(0)
    reduce.139 = f32[2]{0} reduce(multiply.579, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1
    reduce.140.clone.1 = f32[2]{0} reduce(convert.196, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1
    ROOT tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(reduce.139, reduce.140.clone.1)
  })";

  // Check that no loop is generated for reduction.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  const char *expected_ir = R"(
; CHECK-NOT: reduce.0.loop_header
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/true);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{1.0e-5, 1.0e-5}));
}

TEST_F(GpuKernelTilingTest,
       RowReductionWithSmallNonPowerOfTwoDimensionNotTiled) {
  const char *const kHloString = R"(
    HloModule reduction
    reduction0 {
      x0 = f32[] parameter(0)
      y0 = f32[] parameter(1)
      ROOT add0 = f32[] add(x0, y0)
    }

    ENTRY kernel_entry {
      arg0 = f32[8,6,15]{2,1,0}  parameter(0)
      constant0 = f32[] constant(0)
      ROOT reduce0 = f32[8,6]{1,0} reduce(arg0, constant0), dimensions={2},
        to_apply=reduction0
    })";

  // Check that the kernel is not tiled by looking for llvm.nvvm.shfl.sync.down.
  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = R"(
; CHECK-LABEL: define KERNEL_ANNOTATION @reduce
; CHECK-NOT: call SHUFFLE
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecific(expected_ir),
                     /*match_optimized_ir=*/true);

  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}

TEST_F(GpuKernelTilingTest, RowReductionRequiring64BitIndex) {
  const char *const kHloString = R"(
  HloModule LargeReduction

  Sum {
    x.1 = f32[] parameter(0)
    y.1 = f32[] parameter(1)
    ROOT add.1 = f32[] add(x.1, y.1)
  }

  ENTRY reduce.1 {
    parameter = f32[3048576000] parameter(0)
    init_value = f32[] constant(0)
    ROOT out = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum
  }
  )";
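  // The input has 3048576000 elements, which exceeds 2^31 - 1, so element
  // indices cannot fit in an i32 and the emitter must fall back to 64-bit
  // indexing; the pattern below simply looks for an i64 in the generated IR.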
  std::unique_ptr<VerifiedHloModule> hlo_module =
      ParseAndReturnVerifiedModule(kHloString).ValueOrDie();
  const char *expected_ir = R"(
; CHECK: i64
  )";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/true);
}

TEST_F(GpuKernelTilingTest, ColumnReductionVectorization) {
  const char *const kHloString = R"(
HloModule column_reduce_powerof2

reduction {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
}

ENTRY kernel_entry {
    constant0 = f32[] constant(0)
    arg1 = f32[1024,512]{1,0} parameter(0)
    ROOT reduce = f32[512]{0} reduce(arg1, constant0), dimensions={0}, to_apply=reduction
}
  )";
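  // The column reduction should read its f32 input with vectorized
  // <2 x float> loads from global memory rather than scalar loads.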
  auto expected_ir = R"(
; CHECK: load <2 x float>, ptr
  )";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie();
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/true);
}

TEST_F(GpuKernelTilingTest, Hlo021CopyNoOobAccess) {
  const char *const kHloString = R"(
HloModule primitive_computation_svd.38

%fused_computation (param_0.7: f32[3,29,29], param_1.10: pred[3]) -> f32[3,29,29] {
  %param_1.10 = pred[3]{0} parameter(1)
  %broadcast.7 = pred[3,29,29]{2,1,0} broadcast(pred[3]{0} %param_1.10), dimensions={0}
  %param_0.7 = f32[3,29,29]{1,2,0} parameter(0)
  %copy.6 = f32[3,29,29]{2,1,0} copy(f32[3,29,29]{1,2,0} %param_0.7)
  %constant_1 = f32[] constant(nan)
  %broadcast.6 = f32[3,29,29]{2,1,0} broadcast(f32[] %constant_1), dimensions={}
  ROOT %select.0 = f32[3,29,29]{2,1,0} select(pred[3,29,29]{2,1,0} %broadcast.7, f32[3,29,29]{2,1,0} %copy.6, f32[3,29,29]{2,1,0} %broadcast.6)
}

ENTRY %primitive_computation_svd.38 (constant_5: f32[3,29,29], fusion.3: pred[3]) -> f32[3,29,29] {
  %constant_5 = f32[3,29,29]{1,2,0} parameter(0)
  %fusion.3 = pred[3]{0} parameter(1)
  ROOT %fusion = f32[3,29,29]{2,1,0} fusion(f32[3,29,29]{1,2,0} %constant_5, pred[3]{0} %fusion.3), kind=kLoop, calls=%fused_computation
}
  )";

  // Test against the OOB read due to a ptxas bug.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
}

TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) {
  const char *const kHloString = R"(
  HloModule RowReduce

  Sum {
    x.1 = f32[] parameter(0)
    y.1 = f32[] parameter(1)
    ROOT add.1 = f32[] add(x.1, y.1)
  }

  ENTRY reduce.1 {
    parameter = f32[1048576] parameter(0)
    init_value = f32[] constant(0)
    ROOT reduce = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum
  }
  )";
  auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie();
  auto expected_ir = is_built_with_rocm_ ? R"(
; CHECK: initial_value_addr = internal unnamed_addr addrspace({{[0-9]*}}) global [1024 x float] undef, align 4
  )"
                                         : R"(
; CHECK: shared_cache = private unnamed_addr addrspace({{[0-9]*}}) global [1 x [1 x [2 x float]]]
  )";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/true);
}

TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) {
  const char *const kHloString = R"(
  HloModule RowReduce

  Sum {
    x.1 = f32[] parameter(0)
    y.1 = f32[] parameter(1)
    ROOT add.1 = f32[] add(x.1, y.1)
  }

  ENTRY reduce.1 {
    parameter = f32[4,1048576,1024,1024] parameter(0)
    init_value = f32[] constant(0)
    ROOT reduce = f32[4,1048576,1024] reduce(parameter, init_value), dimensions={3}, to_apply=Sum
  }
  )";
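  // The reduction produces 4 * 1048576 * 1024 = 4294967296 output elements,
  // so the resulting number of physical blocks does not fit in an i32 and
  // compilation is expected to fail cleanly instead of emitting a bad kernel.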
  auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie();
  Status status = CompileToExecutable(std::move(hlo_module)).status();
  EXPECT_EQ(status.code(), tensorflow::error::Code::FAILED_PRECONDITION);
  EXPECT_THAT(
      status.error_message(),
      ::testing::HasSubstr(
          "Number of physical blocks (4294967296) does not fit in an i32"));
}

}  // namespace
}  // namespace gpu
}  // namespace xla