1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_support_utils.h"
17
18 #include <memory>
19
20 #include "absl/status/status.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
27 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/platform/status_matchers.h"
36 #include "tensorflow/stream_executor/device_description.h"
37 #include "tensorflow/stream_executor/dnn.h"
38
39 namespace xla {
40 namespace gpu {
41 namespace {
42
43 using ::tensorflow::testing::IsOkAndHolds;
44
45 class CudnnSupportUtilsTest : public HloTestBase {
46 public:
47 // Gets the custom call with `target` from the `module`. Expects that there is
48 // one and only one matching call.
GetCustomCall(xla::VerifiedHloModule * module,absl::string_view target)49 StatusOr<HloCustomCallInstruction*> GetCustomCall(
50 xla::VerifiedHloModule* module, absl::string_view target) {
51 HloCustomCallInstruction* call = nullptr;
52 for (HloComputation* comp : module->MakeNonfusionComputations()) {
53 for (HloInstruction* inst : comp->instructions()) {
54 if (inst->IsCustomCall(target)) {
55 VLOG(1) << inst->ToString();
56 if (call != nullptr) {
57 return tensorflow::errors::FailedPrecondition(
58 "Found more than one custom call.");
59 }
60 call = Cast<HloCustomCallInstruction>(inst);
61 }
62 }
63 }
64 if (call == nullptr) {
65 return tensorflow::errors::FailedPrecondition(
66 "Did not find any matching custom call.");
67 }
68 return call;
69 }
70 };
71
// Checks that only the supported vectorization widths are accepted for an
// otherwise eligible s8 forward convolution on cc7.5.
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedIntegerConvolutionCheckVectorSize) {
  constexpr char kHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[8,10,10,128] parameter(0)
    filter = s8[2,2,128,128] parameter(1)
    ROOT result = (s8[8,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";

  // TF_ASSERT_OK_AND_ASSIGN (rather than ValueOrDie()) reports a parse or
  // verification error as a test failure instead of crashing the binary.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  TF_ASSERT_OK_AND_ASSIGN(HloCustomCallInstruction * conv,
                          GetCustomCall(module.get(), "__cudnn$convForward"));

  // Vector sizes 4 and 32 are accepted.
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(true));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(true));
  // Any other width is rejected.
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 7),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 1),
              IsOkAndHolds(false));  // 1 is not considered a vector size
}
99
// Checks the minimum compute capability required for each vector size:
// int8x4 needs cc6.1+, int8x32 needs cc7.5+.
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedIntegerConvolutionCheckComputeCapability) {
  constexpr char kHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[8,10,10,128] parameter(0)
    filter = s8[2,2,128,128] parameter(1)
    ROOT result = (s8[8,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";

  // Fail the test (rather than abort the binary) if parsing/verification
  // fails.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  TF_ASSERT_OK_AND_ASSIGN(HloCustomCallInstruction * conv,
                          GetCustomCall(module.get(), "__cudnn$convForward"));

  // cc6.1 allows for int8x4
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({6, 0}, *conv, 4),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({6, 1}, *conv, 4),
              IsOkAndHolds(true));

  // cc7.5+ allows for int8x32
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 4}, *conv, 32),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(true));
}
130
// Checks that only forward cuDNN convolutions are reported as supporting the
// optimized integer path; backward-filter and backward-input are rejected.
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedIntegerConvolutionCheckKind) {
  HloCustomCallInstruction* conv = nullptr;

  constexpr char kFwdHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[32,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  // Fail the test (rather than abort the binary) on parse/verify errors.
  TF_ASSERT_OK_AND_ASSIGN(auto moduleFwd,
                          ParseAndReturnVerifiedModule(kFwdHlo));
  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleFwd.get(), "__cudnn$convForward"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(true));

  // NOTE(review): the backward fixtures below use f16 operands, so a `false`
  // result does not by itself prove the kind check (as opposed to the
  // element-type check) is what rejects them.
  constexpr char kBwdFilterHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = f16[10,20,30,41] parameter(0)
    output = f16[10,20,30,40] parameter(1)
    result = (f16[2,2,41,40], u8[0]) custom-call(input, output),
             window={size=2x2}, dim_labels=b01f_01io->b01f,
             custom_call_target="__cudnn$convBackwardFilter"
    ROOT gte = f16[2,2,41,40] get-tuple-element(result), index=0
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto moduleBwdFilter,
                          ParseAndReturnVerifiedModule(kBwdFilterHlo));
  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleBwdFilter.get(), "__cudnn$convBackwardFilter"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));

  constexpr char kBwdInputHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    output = f16[10,20,30,40] parameter(0)
    filter = f16[2,2,41,40] parameter(1)
    result = (f16[10,20,30,41], u8[0]) custom-call(output, filter),
             window={size=2x2}, dim_labels=b01f_01io->b01f,
             custom_call_target="__cudnn$convBackwardInput"
    ROOT gte = f16[10,20,30,41] get-tuple-element(result), index=0
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto moduleBwdInput,
                          ParseAndReturnVerifiedModule(kBwdInputHlo));
  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleBwdInput.get(), "__cudnn$convBackwardInput"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));
}
187
// Checks element-type requirements:
//   * s8 in / s8 out: both int8x4 and int8x32 are accepted.
//   * s8 in / f32 out: int8x4 accepted, int8x32 rejected.
//   * f32 in / f32 out: neither is accepted.
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedVectorizedIntegerConvolutionCheckTypes) {
  HloCustomCallInstruction* conv = nullptr;

  constexpr char kS8InOutHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[32,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  // Fail the test (rather than abort the binary) on parse/verify errors.
  TF_ASSERT_OK_AND_ASSIGN(auto moduleS8InOut,
                          ParseAndReturnVerifiedModule(kS8InOutHlo));
  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleS8InOut.get(), "__cudnn$convForward"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(true));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(true));

  constexpr char kS8InF32OutHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (f32[32,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto moduleS8InF32Out,
                          ParseAndReturnVerifiedModule(kS8InF32OutHlo));
  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleS8InF32Out.get(), "__cudnn$convForward"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(true));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));  // imma output must also be int8_t

  constexpr char kF32InF32OutHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = f32[32,10,10,64] parameter(0)
    filter = f32[2,2,64,128] parameter(1)
    ROOT result = (f32[32,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto moduleF32InF32Out,
                          ParseAndReturnVerifiedModule(kF32InF32OutHlo));
  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleF32InF32Out.get(), "__cudnn$convForward"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));
}
245
// Checks that a 3D (three spatial dimensions) convolution is rejected for both
// vector sizes.
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedVectorizedIntegerConvolutionCheckDims) {
  // This 3d conv should be rejected. The window has three spatial dimensions
  // to match the b012f/012io dim_labels, so the fixture is self-consistent and
  // the only reason for rejection is the convolution's rank.
  constexpr char kHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,10,64] parameter(0)
    filter = s8[2,2,2,64,128] parameter(1)
    ROOT result = (s8[32,10,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2x2}, dim_labels=b012f_012io->b012f,
                  custom_call_target="__cudnn$convForward"
  })";
  // Fail the test (rather than abort the binary) on parse/verify errors.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  TF_ASSERT_OK_AND_ASSIGN(HloCustomCallInstruction * conv,
                          GetCustomCall(module.get(), "__cudnn$convForward"));

  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));
}
269
// Checks that a convolution with window (rhs) dilation is rejected for both
// vector sizes.
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedVectorizedIntegerConvolutionCheckDilation) {
  // A 2x2 filter with rhs_dilate=2x2 spans 3x3; over a 10x10 input with no
  // padding that yields an 8x8 output, so the result shape is consistent with
  // the window.
  constexpr char kHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[32,8,8,128], u8[0]) custom-call(input, filter),
                  window={size=2x2 rhs_dilate=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  // Fail the test (rather than abort the binary) on parse/verify errors.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  TF_ASSERT_OK_AND_ASSIGN(HloCustomCallInstruction * conv,
                          GetCustomCall(module.get(), "__cudnn$convForward"));

  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));
}
291
// Checks the spatial-size constraint that applies only to int8x32.
//   * A 3x3 filter over a 2x2 input: int8x4 accepted, int8x32 rejected.
//   * A 3x3 filter over a 3x3 input: both vector sizes accepted.
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedVectorizedIntegerConvolutionCheckAlgo1Dims) {
  HloCustomCallInstruction* conv = nullptr;

  constexpr char kFilterCoversInputHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,2,2,64] parameter(0)
    filter = s8[3,3,64,128] parameter(1)
    ROOT result = (s8[32,2,2,128], u8[0]) custom-call(input, filter),
                  window={size=3x3}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  // Fail the test (rather than abort the binary) on parse/verify errors.
  TF_ASSERT_OK_AND_ASSIGN(auto moduleFilterCoversInput,
                          ParseAndReturnVerifiedModule(kFilterCoversInputHlo));
  TF_ASSERT_OK_AND_ASSIGN(conv, GetCustomCall(moduleFilterCoversInput.get(),
                                              "__cudnn$convForward"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(true));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));

  constexpr char kFilterAlmostCoversInputHlo[] = R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,3,3,64] parameter(0)
    filter = s8[3,3,64,128] parameter(1)
    ROOT result = (s8[32,3,3,128], u8[0]) custom-call(input, filter),
                  window={size=3x3}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })";
  TF_ASSERT_OK_AND_ASSIGN(
      auto moduleFilterAlmostCoversInput,
      ParseAndReturnVerifiedModule(kFilterAlmostCoversInputHlo));
  TF_ASSERT_OK_AND_ASSIGN(conv,
                          GetCustomCall(moduleFilterAlmostCoversInput.get(),
                                        "__cudnn$convForward"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(true));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(true));
}
332
333 } // namespace
334 } // namespace gpu
335 } // namespace xla
336