/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The CUDA-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_

#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class CudnnRnnDescriptor;
class CudnnRnnSequenceTensorDescriptor;
class CudnnRnnStateTensorDescriptor;
class CudnnCtcLossDescriptor;

// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
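// The id is registered with the PluginRegistry when the library is loaded;
// clients normally reach this implementation through StreamExecutor's DNN
// support lookup rather than by using the id directly.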

// cuDNN-based DNN support, implementing the general dnn::DnnSupport
// interface. For details on the overridden interface functions, see dnn.h.
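//
// Example usage (an illustrative sketch only; assumes `executor` is a live
// GpuExecutor and that the loaded cuDNN library is compatible):
//
//   CudnnSupport dnn(executor);
//   if (!dnn.Init().ok()) {
//     // cuDNN could not be loaded or its version does not match.
//   }
//   port::StatusOr<perftools::gputools::dnn::VersionInfo> version =
//       dnn.GetVersion();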
class CudnnSupport : public dnn::DnnSupport {
 public:
  explicit CudnnSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64 seed, ScratchAllocator* state_allocator,
      bool use_padded_io) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int>& seq_lengths,
                                    bool time_major,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;

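  // Runs the forward pass of the RNN described by `rnn_desc`. Overloads are
  // provided for Eigen::half, float, and double element types; `is_training`
  // selects whether a reserve space is captured for a later backward pass.
  // See dnn.h for the full contract.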
  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

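  // Runs the backward pass of the RNN described by `rnn_desc`, consuming the
  // reserve space captured by the corresponding forward pass and producing
  // gradients with respect to the inputs, initial states, and parameters.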
  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

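  // Enumerates the candidate algorithms for forward convolution, RNNs, and
  // the two convolution backward passes. `cc_major`/`cc_minor` identify the
  // device's CUDA compute capability; `with_winograd_nonfused` controls
  // whether the Winograd non-fused algorithm is included in the list.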
  bool GetConvolveAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

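  // Batch normalization forward pass. The Eigen::half overload keeps scale,
  // offset, and the running statistics in float. In inference mode the
  // estimated mean/variance are consumed; in training mode batch statistics
  // are computed and the saved mean/inverse variance are emitted for reuse
  // by the backward pass.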
  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      std::function<const DeviceMemory<float>&()> var_to_inv_var,
      std::function<void()> inv_var_to_var) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      std::function<const DeviceMemory<float>&()> var_to_inv_var,
      std::function<void()> inv_var_to_var) override;

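  // Batch normalization backward pass; consumes the mean and inverse
  // variance saved by the forward pass.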
  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
      DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

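  // Runs a convolution of the given `kind` (forward, backward-data, or
  // backward-filter) with an already-chosen algorithm and pre-allocated
  // scratch memory; both typically come from DoPrepareForConvolution below,
  // which the dnn::DnnSupport entry points invoke first. Illustrative call
  // shape (names of the surrounding objects are assumptions):
  //
  //   dnn->DoConvolve(dnn::ConvolutionKind::FORWARD, dnn::DataType::kFloat,
  //                   dnn::DataType::kFloat, stream, input_desc, input,
  //                   filter_desc, filter, output_desc, output, conv_desc,
  //                   algorithm_desc, scratch,
  //                   /*output_profile_result=*/nullptr);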
  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;

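  // Fused convolution: computes, in a single cuDNN call, roughly
  //   activation(conv_input_scale * conv(conv_input) +
  //              side_input_scale * side_input + bias)
  // with overloads for double, float, Eigen::half, and int8 element types.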
  bool DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<double>& conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<double>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<double>& side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<float>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<float>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<float>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedConvolve(Stream* stream,
                       const dnn::BatchDescriptor& conv_input_descriptor,
                       const DeviceMemory<Eigen::half>& conv_input_data,
                       float conv_input_scale,
                       const dnn::FilterDescriptor& filter_descriptor,
                       const DeviceMemory<Eigen::half>& filter_data,
                       const dnn::ConvolutionDescriptor& convolution_descriptor,
                       const DeviceMemory<Eigen::half>& side_input_data,
                       float side_input_scale,
                       const dnn::BatchDescriptor& bias_descriptor,
                       const DeviceMemory<Eigen::half>& biases,
                       dnn::ActivationMode activation_mode,
                       const dnn::BatchDescriptor& output_descriptor,
                       DeviceMemory<Eigen::half>* output_data,
                       ScratchAllocator* scratch_allocator,
                       const dnn::AlgorithmConfig& algorithm_config,
                       dnn::ProfileResult* output_profile_result) override;

  bool DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<int8>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<float>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

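  // The quantized-convolution, separable-convolution, and quantized-matmul
  // entry points are not implemented on top of cuDNN; the inline stubs below
  // log an error and return false.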
  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by cuDNN";
    return false;
  }

  bool DoConvolveBackwardBias(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<double>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<double>* backward_bias_data) override;

  bool DoConvolveBackwardBias(Stream* stream,
                              const dnn::BatchDescriptor& input_descriptor,
                              const DeviceMemory<float>& input_data,
                              const dnn::BatchDescriptor& bias_descriptor,
                              DeviceMemory<float>* backward_bias_data) override;

  bool DoConvolveBackwardBias(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<Eigen::half>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<Eigen::half>* backward_bias_data) override;

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by cuDNN";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by cuDNN";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64 options) override;

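  // Pooling forward and backward passes, overloaded for double, float,
  // Eigen::half, and (forward only) int8 element types.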
  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<double>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<double>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<float>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<float>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<Eigen::half>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<int8>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<int8>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<double>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<double>& output_data,
                      const DeviceMemory<double>& input_diff_data,
                      DeviceMemory<double>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<float>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<float>& output_data,
                      const DeviceMemory<float>& input_diff_data,
                      DeviceMemory<float>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<Eigen::half>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<Eigen::half>& output_data,
                      const DeviceMemory<Eigen::half>& input_diff_data,
                      DeviceMemory<Eigen::half>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator) override;

  bool DoDepthConcatenate(
      Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(
      Stream* stream, dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      const dnn::BatchDescriptor& output_dimensions,
      DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
               const DeviceMemory<float>& input_data, int64 left_pad,
               int64 right_pad, int64 top_pad, int64 bottom_pad,
               DeviceMemory<float>* output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
                 const DeviceMemory<float>& input_data, int64 left_trim,
                 int64 right_trim, int64 top_trim, int64 bottom_trim,
                 DeviceMemory<float>* output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64 size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64 size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);

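  // Computes CTC loss and its gradient. `scratch_memory` comes from
  // DoPrepareForCtcLoss below.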
  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                         const dnn::RnnStateTensorDescriptor& probs_desc,
                         const DeviceMemoryBase probs_data,
                         absl::Span<const int> labels_data,
                         absl::Span<const int> labels_lengths_data,
                         absl::Span<const int> input_lengths_data,
                         DeviceMemoryBase costs_data,
                         const dnn::RnnStateTensorDescriptor& grads_desc,
                         DeviceMemoryBase grads_data,
                         DeviceMemory<uint8> scratch_memory) override;

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Provides access to the cuDNN handle.
  std::unique_ptr<class CudnnAccess> cudnn_;

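  // Type-parameterized helpers backing the public overloads above.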
  template <class T, class U>
  port::Status DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const DeviceMemory<U>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
      DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
      DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      std::function<const DeviceMemory<U>&()> var_to_inv_var,
      std::function<void()> inv_var_to_var);

  template <class T, class U>
  port::Status DoBatchNormalizationBackwardImpl(
      Stream* stream, int cudnn_input_type, int cudnn_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
      const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
      DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator);

  template <typename ElementType, typename BiasType, typename ScaleType,
            typename OutputType>
  port::Status DoFusedConvolveImpl(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<ElementType>& conv_input_data,
      ScaleType conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<ElementType>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<OutputType>& side_input_data,
      ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<OutputType>* output_data, dnn::DataType accumulator_type,
      ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoConvolveBackwardBiasImpl(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<T>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<T>* backward_bias_data);

  template <class T>
  port::Status DoRnnForwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      DeviceMemory<T>* output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      DeviceMemory<T>* output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      DeviceMemory<T>* output_c_data, bool is_training,
      ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoRnnBackwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<T>& output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<T>& output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<T>& output_c_data,
      const DeviceMemory<T>& output_backprop_data,
      const DeviceMemory<T>& output_h_backprop_data,
      const DeviceMemory<T>& output_c_backprop_data,
      DeviceMemory<T>* input_backprop_data,
      DeviceMemory<T>* input_h_backprop_data,
      DeviceMemory<T>* input_c_backprop_data,
      DeviceMemory<T>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  port::Status DoCtcLossImpl(
      Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const CudnnRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory);

 private:
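  // Called by the dnn::DnnSupport convolution entry points ahead of
  // DoConvolve to select an algorithm and allocate scratch memory.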
  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator,
      DeviceMemory<uint8>* scratch_memory) override;

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_