/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace gpu {

namespace {

class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest {};

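// With the RHS broadcast folded away, the optimized HLO should call cuBLAS
// directly on the 2-D %y: the custom call carries empty rhs_batch_dimensions
// and the contracting dimension remapped onto the unbroadcast shape.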
TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteRhs) {
  const char* hlo_text = R"(
HloModule BroadcastedInput

ENTRY AddDotsFunc {
  x = f32[3,2,2]{2,1,0} parameter(0)
  y = f32[2,2]{1,0} parameter(1)
  y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,2], y: f32[2,2]) -> f32[3,2,2] {
; CHECK-NEXT: %x = f32[3,2,2]{2,1,0} parameter(0)
; CHECK-NEXT: %y = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT: ROOT %cublas-batch-gemm.1 = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

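// Mirror of the test above: the LHS broadcast is folded, so the custom call
// takes the 2-D %x with empty lhs_batch_dimensions.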
TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteLhs) {
  const char* hlo_text = R"(
HloModule BroadcastedInput

ENTRY AddDotsFunc {
  x = f32[2,2]{1,0} parameter(0)
  y = f32[3,2,2]{2,1,0} parameter(1)
  x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}

)";

  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[3,2,2]) -> f32[3,2,2] {
; CHECK-NEXT: %x = f32[2,2]{1,0} parameter(0)
; CHECK-NEXT: %y = f32[3,2,2]{2,1,0} parameter(1)
; CHECK-NEXT: ROOT %cublas-batch-gemm.1 = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{
; CHECK-DAG: \"alpha_real\":1
; CHECK-DAG: \"alpha_imag\":0
; CHECK-DAG: \"beta\":0
; CHECK-DAG: \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[\"0\"]}
; CHECK-DAG: \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG: \"epilogue\":\"DEFAULT\"
; CHECK-DAG: \"selected_algorithm\":\"{{-?[0-9]+}}\"
)");
}

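// Same HLO as BroadcastedStridedRewriteRhs, but runs the two passes by hand
// and checks that GemmBroadcastFoldingRewriter reports having changed the
// module.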
TEST_F(GemmBroadcastFoldingRewriteTest,
       BroadcastedStridedRewriteRhsPassChanged) {
  const char* hlo_text = R"(
HloModule BroadcastedInput

ENTRY AddDotsFunc {
  x = f32[3,2,2]{2,1,0} parameter(0)
  y = f32[2,2]{1,0} parameter(1)
  y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}

)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Run GemmRewriter first to generate the cuBLAS GEMM custom call.
  GemmRewriter gemm_rewriter;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          this->RunHloPass(&gemm_rewriter, module.get()));
  EXPECT_TRUE(changed);
  GemmBroadcastFoldingRewriter pass;
  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
  EXPECT_TRUE(changed);
}

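// LHS counterpart of the test above.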
TEST_F(GemmBroadcastFoldingRewriteTest,
       BroadcastedStridedRewriteLhsPassChanged) {
  const char* hlo_text = R"(
HloModule BroadcastedInput

ENTRY AddDotsFunc {
  x = f32[2,2]{1,0} parameter(0)
  y = f32[3,2,2]{2,1,0} parameter(1)
  x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}

)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Run GemmRewriter first to generate the cuBLAS GEMM custom call.
  GemmRewriter gemm_rewriter;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          this->RunHloPass(&gemm_rewriter, module.get()));
  EXPECT_TRUE(changed);
  GemmBroadcastFoldingRewriter pass;
  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
  EXPECT_TRUE(changed);
}

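// The broadcast operand's batch dimension is dimension 1, not 0, so the
// folding does not apply and the pass must report no change.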
TEST_F(GemmBroadcastFoldingRewriteTest, LHSBatchDimNonZero) {
  const char* hlo_text = R"(
HloModule LHSBatchDimNonZero

ENTRY %LHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
  %Arg_1 = f32[4,3]{1,0} parameter(0)
  %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
  %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
  ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[7,4,3]{2,1,0} %broadcast.22, f32[4,7,3]{2,1,0} %Arg_2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // Run GemmRewriter first to generate the cuBLAS GEMM custom call.
  GemmRewriter gemm_rewriter;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          this->RunHloPass(&gemm_rewriter, module.get()));
  EXPECT_TRUE(changed);
  GemmBroadcastFoldingRewriter pass;
  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
  EXPECT_FALSE(changed);
}

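// RHS counterpart of the test above: rhs_batch_dims={1}, so no folding.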
TEST_F(GemmBroadcastFoldingRewriteTest, RHSBatchDimNonZero) {
  const char* hlo_text = R"(
HloModule RHSBatchDimNonZero

ENTRY %RHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
  %Arg_1 = f32[4,3]{1,0} parameter(0)
  %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
  %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
  ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[4,7,3]{2,1,0} %Arg_2, f32[7,4,3]{2,1,0} %broadcast.22), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2}
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GemmRewriter gemm_rewriter;
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          this->RunHloPass(&gemm_rewriter, module.get()));
  EXPECT_TRUE(changed);
  GemmBroadcastFoldingRewriter pass;
  TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
  EXPECT_FALSE(changed);
}

}  // namespace
}  // namespace gpu
}  // namespace xla