• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <string>

#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
18 
19 namespace xla {
20 namespace gpu {
21 namespace {
22 
23 class FusionLogicalIndexTest : public GpuCodegenTest {};
24 
// Checks that a kLoop fusion whose root reshapes into an output with a
// non-default layout stores each element at the correct physical address.
//
// The root shape f32[1,480,400,64] has layout {2,1,3,0} (minor-to-major), so
// its physical dimension order, major to minor, is [1, 64, 480, 400]. The
// emitted kernel is expected to split the linear index into logical indices
// (dimensions of size 64, 400 and 480, in that order) and then address the
// output through a GEP over the physical array type, with the logical indices
// permuted into physical order.
TEST_F(FusionLogicalIndexTest, FusionLogicalIndexStore) {
  const char* hlo_text = R"(
HloModule TestModule

fused_computation.18 {
  select.12 = f32[768000,16]{1,0} parameter(0)
  select.13 = f32[768000,16]{1,0} parameter(1)
  maximum.1 = f32[768000,16]{1,0} maximum(select.13, select.12)
  ROOT reshape.437 = f32[1,480,400,64]{2,1,3,0} reshape(maximum.1)
}


ENTRY entry {
    select.12 = f32[768000,16]{1,0} parameter(0)
    select.13 = f32[768000,16]{1,0} parameter(1)
    ROOT fusion.18 = f32[1,480,400,64]{2,1,3,0} fusion(select.12, select.13), kind=kLoop, calls=fused_computation.18
}
)";

  // Numerical check of the compiled fusion.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));

  // Only the intrinsics that read the block and thread ids differ between the
  // AMDGPU and NVPTX backends. Trailing metadata such as `!range !N` is not
  // matched: FileCheck matches substrings, and metadata node numbers depend on
  // emission order rather than on the indexing under test.
  const std::string id_checks = is_built_with_rocm_ ? R"(
; CHECK:  %[[block_id:.*]] = call i32 @llvm.amdgcn.workgroup.id.x()
; CHECK:  %[[thread_id:.*]] = call i32 @llvm.amdgcn.workitem.id.x()
)"
                                                    : R"(
; CHECK:  %[[block_id:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
; CHECK:  %[[thread_id:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
)";

  // Index arithmetic and the store, identical for both backends. The GEP base
  // pointer is matched with a wildcard instead of a hard-coded SSA name: the
  // physical array type [1 x [64 x [480 x [400 x float]]]] already singles out
  // the output buffer (the inputs are f32[768000,16]).
  const char* const index_checks = R"(
; CHECK:  %[[block_start_index:.*]] = mul nuw nsw i32 %[[block_id]], [[block_size:.*]]
; CHECK:  %[[linear_index:.*]] = add nuw nsw i32 %[[block_start_index]], %[[thread_id]]
; CHECK:  %[[index_1:.*]] = urem i32 %[[linear_index]], 64
; CHECK:  %[[base_1:.*]] = udiv i32 %[[linear_index]], 64
; CHECK:  %[[index_3:.*]] = urem i32 %[[base_1]], 400
; CHECK:  %[[base_3:.*]] = udiv i32 %[[base_1]], 400
; CHECK:  %[[index_2:.*]] = urem i32 %[[base_3]], 480
; CHECK:  %[[pointer:.*]] = getelementptr inbounds [1 x [64 x [480 x [400 x float]]]], ptr %{{.*}}, i32 0, i32 0, i32 %[[index_1]], i32 %[[index_2]], i32 %[[index_3]]
; CHECK:  store float %[[result_value:.*]], ptr %[[pointer]], align 4
)";

  CompileAndVerifyIr(hlo_text, id_checks + index_checks);
}
75 
76 }  // namespace
77 }  // namespace gpu
78 }  // namespace xla
79