/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {

namespace {

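// Codegen tests that compile through the standard GPU pipeline.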
class ReductionVectorizationTest : public GpuCodegenTest {};

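// Variant of the fixture above that disables all HLO passes, so hand-written
// fusions in the test HLO reach codegen unchanged.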
class ReductionVectorizationNoOptTest : public GpuCodegenTest {
  DebugOptions GetDebugOptionsForTest() override {
    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
    // The MultiOutputStore test contains a multi-output fusion (MOF) that the
    // XLA optimizer passes would rewrite, so disable all HLO passes.
    debug_options.set_xla_disable_all_hlo_passes(true);
    return debug_options;
  }
};
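// Checks that a multi-output reduction fusion gets vectorized: every global
// load and store in the generated PTX should be a 2-wide (v2.f32) access.
// Note that the constant 0.0009765625 is 1/1024, the reciprocal of the
// reduced dimension, so the fusion roughly computes a sum of squared
// deviations from the per-row mean while storing the intermediates as
// additional outputs.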
TEST_F(ReductionVectorizationNoOptTest, MultiOutputStore) {
  const char* hlo_text = R"(
HloModule MultiOutputStore

%add_f32 {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

%fused_computation {
  %param_0 = f32[2,384,1024] parameter(0)
  %param_1 = f32[2,384] parameter(1)
  %constant0 = f32[] constant(0.0009765625)
  %broadcast0 = f32[2,384] broadcast(%constant0), dimensions={}
  %multiply0 = f32[2,384] multiply(%param_1, %broadcast0)
  %broadcast1 = f32[2,384,1024] broadcast(%multiply0), dimensions={0,1}
  %subtract = f32[2,384,1024] subtract(%param_0, %broadcast1)
  %multiply1 = f32[2,384,1024] multiply(%subtract, %subtract)
  %constant1 = f32[] constant(0)
  %reduce = f32[2,384] reduce(%multiply1, %constant1), dimensions={2}, to_apply=%add_f32
  ROOT %tuple = (f32[2,384], f32[2,384,1024], f32[2,384,1024]) tuple(%reduce, %subtract, %broadcast1)
}

ENTRY %cluster {
  %param0 = f32[2,384,1024] parameter(0)
  %param1 = f32[2,384] parameter(1)
  ROOT %fusion = (f32[2,384], f32[2,384,1024], f32[2,384,1024]) fusion(%param0, %param1), kind=kInput, calls=%fused_computation
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  std::string expected = R"(
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
CHECK: ld.global.nc.v2.f32
CHECK: st.global.v2.f32
CHECK: st.global.v2.f32
)";
  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected);

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
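// A reduction over a 32-element minor dimension fits in a single warp, so
// the emitter should issue scalar (non-vectorized) loads and a single
// intra-warp shuffle reduction with no inter-warp step. The IR and PTX
// checks below verify both properties.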
TEST_F(ReductionVectorizationTest, NoVectorizationForBlockSmallerThanWarpSize) {
  const char* hlo_text = R"(
HloModule SlowModule

%search_fn (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add0 = f32[] add(f32[] %x, f32[] %y)
}

ENTRY %fused_computation.371 (param_0: f32[6400,4,8,32]) -> f32[6400,4,8] {
  %param_0 = f32[6400,4,8,32]{3,2,1,0} parameter(0)
  %constant_0 = f32[] constant(0.0)
  ROOT %reduce.277 = f32[6400,4,8]{2,1,0} reduce(f32[6400,4,8,32]{3,2,1,0} %param_0, f32[] %constant_0), dimensions={3}, to_apply=%search_fn
}
)";

  std::string expected_optimized_llvm_ir = R"(
CHECK:  %[[thread_id:.*]] = tail call i32 X_THREAD
CHECK:  %[[masked_thread_id:.*]] = and i32 %[[thread_id]], 31
// Verify that there is no comparison masking half the warp.
CHECK-NOT: icmp ult i32 %[[masked_thread_id]], 16
// Verify that we only do one warp reduction by checking that there are
// exactly 6 matches of the shuffle intrinsic: 1 declaration and 5 shuffle
// instructions. The second warp reduction, originally emitted for the
// inter-warp step, has been optimized away.
CHECK-COUNT-6: SHUFFLE
CHECK-NOT: SHUFFLE
)";

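  // Rewrite the backend-agnostic placeholders into the intrinsics each
  // platform actually emits: workitem.id.x / ds.bpermute on ROCm,
  // read.ptx.sreg.tid.x / shfl.sync.down on CUDA.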
  expected_optimized_llvm_ir = absl::StrReplaceAll(
      expected_optimized_llvm_ir,
      {{"X_THREAD", is_built_with_rocm_ ? "@llvm.amdgcn.workitem.id.x"
                                        : "@llvm.nvvm.read.ptx.sreg.tid.x"},
       {"SHUFFLE", is_built_with_rocm_ ? "llvm.amdgcn.ds.bpermute"
                                       : "llvm.nvvm.shfl.sync.down.f32"}});

  CompileAndVerifyIr(hlo_text, expected_optimized_llvm_ir, true);

  // Check that there is a single scalar load.
  const char* expected_ptx = R"(
CHECK: ld.global.nc.f32
CHECK: shfl.sync.down
CHECK-NOT: ld.global.nc.f32
CHECK-NOT: ld.global.v2.f32
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

}  // namespace
}  // namespace gpu
}  // namespace xla