• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cudnn_support_utils.h"
17 
18 #include <memory>
19 
20 #include "absl/status/status.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
27 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/platform/status_matchers.h"
36 #include "tensorflow/stream_executor/device_description.h"
37 #include "tensorflow/stream_executor/dnn.h"
38 
39 namespace xla {
40 namespace gpu {
41 namespace {
42 
43 using ::tensorflow::testing::IsOkAndHolds;
44 
45 class CudnnSupportUtilsTest : public HloTestBase {
46  public:
47   // Gets the custom call with `target` from the `module`. Expects that there is
48   // one and only one matching call.
GetCustomCall(xla::VerifiedHloModule * module,absl::string_view target)49   StatusOr<HloCustomCallInstruction*> GetCustomCall(
50       xla::VerifiedHloModule* module, absl::string_view target) {
51     HloCustomCallInstruction* call = nullptr;
52     for (HloComputation* comp : module->MakeNonfusionComputations()) {
53       for (HloInstruction* inst : comp->instructions()) {
54         if (inst->IsCustomCall(target)) {
55           VLOG(1) << inst->ToString();
56           if (call != nullptr) {
57             return tensorflow::errors::FailedPrecondition(
58                 "Found more than one custom call.");
59           }
60           call = Cast<HloCustomCallInstruction>(inst);
61         }
62       }
63     }
64     if (call == nullptr) {
65       return tensorflow::errors::FailedPrecondition(
66           "Did not find any matching custom call.");
67     }
68     return call;
69   }
70 };
71 
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedIntegerConvolutionCheckVectorSize) {
  // A plain int8 forward convolution; only the requested vector size varies.
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[8,10,10,128] parameter(0)
    filter = s8[2,2,128,128] parameter(1)
    ROOT result = (s8[8,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(HloCustomCallInstruction* conv,
                          GetCustomCall(module.get(), "__cudnn$convForward"));

  // Queries the helper for `conv` at compute capability 7.5.
  auto supports_on_cc75 = [&](int vector_size) {
    return CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv,
                                                    vector_size);
  };

  EXPECT_THAT(supports_on_cc75(4), IsOkAndHolds(true));
  EXPECT_THAT(supports_on_cc75(32), IsOkAndHolds(true));
  EXPECT_THAT(supports_on_cc75(7), IsOkAndHolds(false));
  // 1 is not considered a vector size.
  EXPECT_THAT(supports_on_cc75(1), IsOkAndHolds(false));
}
99 
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedIntegerConvolutionCheckComputeCapability) {
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[8,10,10,128] parameter(0)
    filter = s8[2,2,128,128] parameter(1)
    ROOT result = (s8[8,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                    .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(HloCustomCallInstruction* conv,
                          GetCustomCall(module.get(), "__cudnn$convForward"));

  // Queries the helper for `conv` at compute capability `major`.`minor`.
  auto supports = [&](int major, int minor, int vector_size) {
    return CudnnSupportsOptimizedIntegerConvolution({major, minor}, *conv,
                                                    vector_size);
  };

  // int8x4 becomes available starting with cc6.1.
  EXPECT_THAT(supports(6, 0, 4), IsOkAndHolds(false));
  EXPECT_THAT(supports(6, 1, 4), IsOkAndHolds(true));

  // int8x32 becomes available starting with cc7.5.
  EXPECT_THAT(supports(7, 4, 32), IsOkAndHolds(false));
  EXPECT_THAT(supports(7, 5, 32), IsOkAndHolds(true));
}
130 
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedIntegerConvolutionCheckKind) {
  // Only the forward convolution is expected to be accepted. All three modules
  // use int8 operands and results, so the backward convolutions cannot be
  // rejected merely for having a non-integer element type; the custom-call
  // target (i.e. the convolution kind) is what distinguishes them. Both
  // vector sizes are checked because int8x32 is subject to extra filter/input
  // size restrictions that int8x4 is not.
  auto moduleFwd = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[32,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                       .ValueOrDie();

  HloCustomCallInstruction* conv;
  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleFwd.get(), "__cudnn$convForward"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(true));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(true));

  auto moduleBwdFilter = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[10,20,30,41] parameter(0)
    output = s8[10,20,30,40] parameter(1)
    result = (s8[2,2,41,40], u8[0]) custom-call(input, output),
              window={size=2x2}, dim_labels=b01f_01io->b01f,
              custom_call_target="__cudnn$convBackwardFilter"
    ROOT gte = s8[2,2,41,40] get-tuple-element(result), index=0
  })")
                             .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleBwdFilter.get(), "__cudnn$convBackwardFilter"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));

  auto moduleBwdInput = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    output = s8[10,20,30,40] parameter(0)
    filter = s8[2,2,41,40] parameter(1)
    result = (s8[10,20,30,41], u8[0]) custom-call(output, filter),
              window={size=2x2}, dim_labels=b01f_01io->b01f,
              custom_call_target="__cudnn$convBackwardInput"
    ROOT gte = s8[10,20,30,41] get-tuple-element(result), index=0
  })")
                            .ValueOrDie();

  TF_ASSERT_OK_AND_ASSIGN(
      conv, GetCustomCall(moduleBwdInput.get(), "__cudnn$convBackwardInput"));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));
}
187 
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedVectorizedIntegerConvolutionCheckTypes) {
  // Queries the helper for convolution `c` at compute capability 7.5.
  auto supports = [](HloCustomCallInstruction* c, int vector_size) {
    return CudnnSupportsOptimizedIntegerConvolution({7, 5}, *c, vector_size);
  };

  // int8 inputs and int8 output: both vector sizes are accepted.
  auto moduleS8InOut = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[32,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                           .ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(
      HloCustomCallInstruction* s8_out_conv,
      GetCustomCall(moduleS8InOut.get(), "__cudnn$convForward"));
  EXPECT_THAT(supports(s8_out_conv, 4), IsOkAndHolds(true));
  EXPECT_THAT(supports(s8_out_conv, 32), IsOkAndHolds(true));

  // int8 inputs with f32 output: int8x4 only, since the imma (int8x32) path
  // requires an int8_t output as well.
  auto moduleS8InF32Out = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (f32[32,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                              .ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(
      HloCustomCallInstruction* f32_out_conv,
      GetCustomCall(moduleS8InF32Out.get(), "__cudnn$convForward"));
  EXPECT_THAT(supports(f32_out_conv, 4), IsOkAndHolds(true));
  EXPECT_THAT(supports(f32_out_conv, 32), IsOkAndHolds(false));

  // Floating-point inputs: neither vector size is accepted.
  auto moduleF32InF32Out = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = f32[32,10,10,64] parameter(0)
    filter = f32[2,2,64,128] parameter(1)
    ROOT result = (f32[32,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                               .ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(
      HloCustomCallInstruction* f32_conv,
      GetCustomCall(moduleF32InF32Out.get(), "__cudnn$convForward"));
  EXPECT_THAT(supports(f32_conv, 4), IsOkAndHolds(false));
  EXPECT_THAT(supports(f32_conv, 32), IsOkAndHolds(false));
}
245 
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedVectorizedIntegerConvolutionCheckDims) {
  // This 3d conv should be rejected for both vector sizes. The window lists
  // one entry per spatial dimension (2x2x2, matching the three spatial
  // dimensions in dim_labels), so the module is self-consistent. Rank is then
  // the only thing that sets it apart from the 2D convolutions accepted in
  // the other tests.
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,10,64] parameter(0)
    filter = s8[2,2,2,64,128] parameter(1)
    ROOT result = (s8[32,10,10,10,128], u8[0]) custom-call(input, filter),
                  window={size=2x2x2}, dim_labels=b012f_012io->b012f,
                  custom_call_target="__cudnn$convForward"
  })")
                    .ValueOrDie();
  HloCustomCallInstruction* conv;
  TF_ASSERT_OK_AND_ASSIGN(conv,
                          GetCustomCall(module.get(), "__cudnn$convForward"));

  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              IsOkAndHolds(false));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              IsOkAndHolds(false));
}
269 
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedVectorizedIntegerConvolutionCheckDilation) {
  // An int8 forward convolution whose filter is dilated (rhs_dilate=2x2);
  // it is expected to be rejected for both vector sizes.
  auto module = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,10,10,64] parameter(0)
    filter = s8[2,2,64,128] parameter(1)
    ROOT result = (s8[32,20,20,128], u8[0]) custom-call(input, filter),
                  window={size=2x2 rhs_dilate=2x2}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                    .ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(HloCustomCallInstruction* conv,
                          GetCustomCall(module.get(), "__cudnn$convForward"));

  const auto rejected = IsOkAndHolds(false);
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 4),
              rejected);
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution({7, 5}, *conv, 32),
              rejected);
}
291 
TEST_F(CudnnSupportUtilsTest,
       CudnnSupportsOptimizedVectorizedIntegerConvolutionCheckAlgo1Dims) {
  // A 3x3 filter over a 2x2 input: int8x4 is accepted, int8x32 is not.
  auto moduleFilterCoversInput = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,2,2,64] parameter(0)
    filter = s8[3,3,64,128] parameter(1)
    ROOT result = (s8[32,2,2,128], u8[0]) custom-call(input, filter),
                  window={size=3x3}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                                     .ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(
      HloCustomCallInstruction* covering_conv,
      GetCustomCall(moduleFilterCoversInput.get(), "__cudnn$convForward"));
  EXPECT_THAT(
      CudnnSupportsOptimizedIntegerConvolution({7, 5}, *covering_conv, 4),
      IsOkAndHolds(true));
  EXPECT_THAT(
      CudnnSupportsOptimizedIntegerConvolution({7, 5}, *covering_conv, 32),
      IsOkAndHolds(false));

  // A 3x3 filter over a 3x3 input: both vector sizes are accepted.
  auto moduleFilterAlmostCoversInput = ParseAndReturnVerifiedModule(R"(
  HloModule TestModule

  ENTRY TestComputation {
    input = s8[32,3,3,64] parameter(0)
    filter = s8[3,3,64,128] parameter(1)
    ROOT result = (s8[32,3,3,128], u8[0]) custom-call(input, filter),
                  window={size=3x3}, dim_labels=b01f_01io->b01f,
                  custom_call_target="__cudnn$convForward"
  })")
                                           .ValueOrDie();
  TF_ASSERT_OK_AND_ASSIGN(
      HloCustomCallInstruction* almost_covering_conv,
      GetCustomCall(moduleFilterAlmostCoversInput.get(),
                    "__cudnn$convForward"));
  EXPECT_THAT(
      CudnnSupportsOptimizedIntegerConvolution({7, 5}, *almost_covering_conv, 4),
      IsOkAndHolds(true));
  EXPECT_THAT(CudnnSupportsOptimizedIntegerConvolution(
                  {7, 5}, *almost_covering_conv, 32),
              IsOkAndHolds(true));
}
332 
333 }  // namespace
334 }  // namespace gpu
335 }  // namespace xla
336