• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/compiler/xla/error_spec.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_broadcast_folding_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/core/platform/test.h"

22 namespace xla {
23 namespace gpu {
24 
25 namespace {
26 
27 class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest {};
28 
// A broadcast feeding the RHS of a batched dot should be folded into the
// cuBLAS gemm custom-call: the optimized HLO calls the gemm directly on the
// un-broadcast 2D operand %y, with empty rhs_batch_dimensions in the backend
// config (the RHS contracting dimension becomes 0 accordingly).
TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteRhs) {
  const char* hlo_text = R"(
HloModule BroadcastedInput

ENTRY AddDotsFunc {
  x = f32[3,2,2]{2,1,0} parameter(0)
  y = f32[2,2]{1,0} parameter(1)
  y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
  ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
}

)";

  // Numerical result after folding must still match the reference backend.
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
  MatchOptimizedHlo(hlo_text,
                    R"(

; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[3,2,2], y: f32[2,2]) -> f32[3,2,2] {
; CHECK-NEXT:    %x = f32[3,2,2]{2,1,0} parameter(0)
; CHECK-NEXT:    %y = f32[2,2]{1,0} parameter(1)
; CHECK-NEXT:    ROOT %cublas-batch-gemm.1 = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{
; CHECK-DAG:       \"alpha_real\":1
; CHECK-DAG:       \"alpha_imag\":0
; CHECK-DAG:       \"beta\":0
; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[]}
; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
      )");
}
59 
60 TEST_F(GemmBroadcastFoldingRewriteTest, BroadcastedStridedRewriteLhs) {
61   const char* hlo_text = R"(
62 HloModule BroadcastedInput
63 
64 ENTRY AddDotsFunc {
65   x = f32[2,2]{1,0} parameter(0)
66   y = f32[3,2,2]{2,1,0} parameter(1)
67   x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
68   ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
69 }
70 
71 )";
72 
73   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
74   MatchOptimizedHlo(hlo_text,
75                     R"(
76 
77 ; CHECK-LABEL: ENTRY %AddDotsFunc (x: f32[2,2], y: f32[3,2,2]) -> f32[3,2,2] {
78 ; CHECK-NEXT:    %x = f32[2,2]{1,0} parameter(0)
79 ; CHECK-NEXT:    %y = f32[3,2,2]{2,1,0} parameter(1)
80 ; CHECK-NEXT:    ROOT %cublas-batch-gemm.1 = f32[3,2,2]{2,1,0} custom-call(%x, %y), custom_call_target="__cublas$gemm", backend_config="{
81 ; CHECK-DAG:       \"alpha_real\":1
82 ; CHECK-DAG:       \"alpha_imag\":0
83 ; CHECK-DAG:       \"beta\":0
84 ; CHECK-DAG:       \"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[\"0\"]}
85 ; CHECK-DAG:       \"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]}
86 ; CHECK-DAG:       \"epilogue\":\"DEFAULT\"
87 ; CHECK-DAG:       \"selected_algorithm\":\"{{-?[0-9]+}}\"
88       )");
89 }
90 
91 TEST_F(GemmBroadcastFoldingRewriteTest,
92        BroadcastedStridedRewriteRhsPassChanged) {
93   const char* hlo_text = R"(
94 HloModule BroadcastedInput
95 
96 ENTRY AddDotsFunc {
97   x = f32[3,2,2]{2,1,0} parameter(0)
98   y = f32[2,2]{1,0} parameter(1)
99   y_broadcast = f32[3,2,2]{2,1,0} broadcast(y), dimensions={1,2}
100   ROOT dot_a = f32[3,2,2]{2,1,0} dot(x, y_broadcast), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
101 }
102 
103 )";
104 
105   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
106                           ParseAndReturnVerifiedModule(hlo_text));
107   // Use GemmRewriter to generate cublasGemm call.
108   GemmRewriter gemm_rewriter;
109   TF_ASSERT_OK_AND_ASSIGN(bool changed,
110                           this->RunHloPass(&gemm_rewriter, module.get()));
111   EXPECT_TRUE(changed);
112   GemmBroadcastFoldingRewriter pass;
113   TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
114   EXPECT_TRUE(changed);
115 }
116 
117 TEST_F(GemmBroadcastFoldingRewriteTest,
118        BroadcastedStridedRewriteLhsPassChanged) {
119   const char* hlo_text = R"(
120 HloModule BroadcastedInput
121 
122 ENTRY AddDotsFunc {
123   x = f32[2,2]{1,0} parameter(0)
124   y = f32[3,2,2]{2,1,0} parameter(1)
125   x_broadcast = f32[3,2,2]{2,1,0} broadcast(x), dimensions={1,2}
126   ROOT dot_a = f32[3,2,2]{2,1,0} dot(x_broadcast, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1}
127 }
128 
129 )";
130 
131   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
132                           ParseAndReturnVerifiedModule(hlo_text));
133   // Use GemmRewriter to generate cublasGemm call.
134   GemmRewriter gemm_rewriter;
135   TF_ASSERT_OK_AND_ASSIGN(bool changed,
136                           this->RunHloPass(&gemm_rewriter, module.get()));
137   EXPECT_TRUE(changed);
138   GemmBroadcastFoldingRewriter pass;
139   TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
140   EXPECT_TRUE(changed);
141 }
142 
143 TEST_F(GemmBroadcastFoldingRewriteTest, LHSBatchDimNonZero) {
144   const char* hlo_text = R"(
145 HloModule LHSBatchDimNonZero
146 
147 ENTRY %LHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
148   %Arg_1 = f32[4,3]{1,0} parameter(0)
149   %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
150   %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
151   ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[7,4,3]{2,1,0} %broadcast.22, f32[4,7,3]{2,1,0} %Arg_2), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={2}
152 }
153 )";
154   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
155   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
156                           ParseAndReturnVerifiedModule(hlo_text));
157   // Use GemmRewriter to generate cublasGemm call.
158   GemmRewriter gemm_rewriter;
159   TF_ASSERT_OK_AND_ASSIGN(bool changed,
160                           this->RunHloPass(&gemm_rewriter, module.get()));
161   EXPECT_TRUE(changed);
162   GemmBroadcastFoldingRewriter pass;
163   TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
164   EXPECT_FALSE(changed);
165 }
166 
167 TEST_F(GemmBroadcastFoldingRewriteTest, RHSBatchDimNonZero) {
168   const char* hlo_text = R"(
169 HloModule RHSBatchDimNonZero
170 
171 ENTRY %RHSBatchDimNonZero (Arg_1: f32[4,3], Arg_2: f32[4,7,3]) -> f32[4,7,7] {
172   %Arg_1 = f32[4,3]{1,0} parameter(0)
173   %Arg_2 = f32[4,7,3]{2,1,0} parameter(1)
174   %broadcast.22 = f32[7,4,3]{2,1,0} broadcast(f32[4,3]{1,0} %Arg_1), dimensions={1,2}
175   ROOT %dot.24 = f32[4,7,7]{2,1,0} dot(f32[4,7,3]{2,1,0} %Arg_2, f32[7,4,3]{2,1,0} %broadcast.22), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={1}, rhs_contracting_dims={2}
176 }
177 )";
178   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
179   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
180                           ParseAndReturnVerifiedModule(hlo_text));
181   GemmRewriter gemm_rewriter;
182   TF_ASSERT_OK_AND_ASSIGN(bool changed,
183                           this->RunHloPass(&gemm_rewriter, module.get()));
184   EXPECT_TRUE(changed);
185   GemmBroadcastFoldingRewriter pass;
186   TF_ASSERT_OK_AND_ASSIGN(changed, this->RunHloPass(&pass, module.get()));
187   EXPECT_FALSE(changed);
188 }
189 
190 }  // namespace
191 }  // namespace gpu
192 }  // namespace xla
193