/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The CUDA-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_

#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class CudnnRnnDescriptor;
class CudnnRnnSequenceTensorDescriptor;
class CudnnRnnStateTensorDescriptor;
class CudnnCtcLossDescriptor;

// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
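//
// For reference, a sketch of how a plugin id like this is typically wired
// into PluginRegistry from the corresponding .cc file. The exact factory
// signature is an assumption and varies across StreamExecutor versions:
//
//   void initialize_cudnn() {
//     port::Status status =
//         PluginRegistry::Instance()
//             ->RegisterFactory<PluginRegistry::DnnFactory>(
//                 cuda::kCudaPlatformId, kCuDnnPlugin, "cuDNN",
//                 [](internal::StreamExecutorInterface* parent)
//                     -> dnn::DnnSupport* {
//                   return new CudnnSupport(
//                       static_cast<GpuExecutor*>(parent));
//                 });
//   }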

// cuDNN-library-based DNN support. For details on overridden interface
// functions, see dnn.h.
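//
// A minimal usage sketch (assuming a StreamExecutor bound to a CUDA device;
// clients normally reach this class through the generic dnn::DnnSupport
// interface rather than by constructing it directly):
//
//   StreamExecutor* executor = ...;  // executor for a CUDA device
//   dnn::DnnSupport* dnn = executor->AsDnn();
//   if (dnn == nullptr) {
//     // The cuDNN plugin is not linked in or failed to initialize.
//   }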
class CudnnSupport : public dnn::DnnSupport {
 public:
  explicit CudnnSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64 seed, ScratchAllocator* state_allocator,
      bool use_padded_io) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int>& seq_lengths,
                                    bool time_major,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool GetConvolveAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveExecutionPlans(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>* out_exec_plans)
      override;

  port::Status GetFusedConvolveExecutionPlans(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      double conv_input_scale, double side_input_scale, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::ActivationMode activation_mode,
      std::vector<std::unique_ptr<dnn::ConvolveExecutionPlan>>* out_exec_plans);

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<Eigen::half>& side_input,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
      DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoConvolveWithExecutionPlan(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& plan_config,
      ScratchAllocator* scratch_allocator,
      dnn::ProfileResult* output_profile_result);

  port::Status DoFusedConvolve(
      Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
      dnn::DataType bias_type, dnn::DataType output_type,
      const dnn::BatchDescriptor& conv_input_descriptor,
      DeviceMemoryBase conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      DeviceMemoryBase side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
      dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolveWithExecutionPlan(
      Stream* stream, dnn::DataType element_type,
      const dnn::BatchDescriptor& conv_input_descriptor,
      DeviceMemoryBase conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      DeviceMemoryBase side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
      dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result);

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by CUDNN";
    return false;
  }

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by CUDNN";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64 options) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<double>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<double>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<float>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<float>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<Eigen::half>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<int8>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<int8>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<double>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<double>& output_data,
                      const DeviceMemory<double>& input_diff_data,
                      DeviceMemory<double>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<float>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<float>& output_data,
                      const DeviceMemory<float>& input_diff_data,
                      DeviceMemory<float>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<Eigen::half>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<Eigen::half>& output_data,
                      const DeviceMemory<Eigen::half>& input_diff_data,
                      DeviceMemory<Eigen::half>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator) override;

  bool DoDepthConcatenate(
      Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(
      Stream* stream, dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      const dnn::BatchDescriptor& output_dimensions,
      DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
               const DeviceMemory<float>& input_data, int64_t left_pad,
               int64_t right_pad, int64_t top_pad, int64_t bottom_pad,
               DeviceMemory<float>* output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
                 const DeviceMemory<float>& input_data, int64_t left_trim,
                 int64_t right_trim, int64_t top_trim, int64_t bottom_trim,
                 DeviceMemory<float>* output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64_t size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64_t size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
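  //
  // A minimal usage sketch (the surrounding identifiers are illustrative
  // assumptions, not part of this interface):
  //
  //   dnn::BatchDescriptor output_desc;
  //   if (!cudnn_support->DeriveOutputBatchDescriptor(
  //           input_desc, filter_desc, conv_desc, &output_desc)) {
  //     // Fall back or report an error; the derivation can fail.
  //   }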
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);

  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                         const dnn::RnnStateTensorDescriptor& probs_desc,
                         const DeviceMemoryBase probs_data,
                         absl::Span<const int> labels_data,
                         absl::Span<const int> labels_lengths_data,
                         absl::Span<const int> input_lengths_data,
                         DeviceMemoryBase costs_data,
                         const dnn::RnnStateTensorDescriptor& grads_desc,
                         DeviceMemoryBase grads_data,
                         DeviceMemory<uint8> scratch_memory,
                         int ctc_loss_algo_id) override;

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Provides access to the cuDNN handle.
  std::unique_ptr<class CudnnAccess> cudnn_;

  template <class T, class U>
  port::Status DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
      DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
      DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator);

  template <class T, class U>
  port::Status DoBatchNormalizationBackwardImpl(
      Stream* stream, int cudnn_input_type, int cudnn_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
      const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
      DeviceMemory<U>* offset_backprop, DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator);

  template <class T>
  port::Status DoRnnForwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const DeviceMemory<int>& seq_lengths_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      DeviceMemory<T>* output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      DeviceMemory<T>* output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      DeviceMemory<T>* output_c_data, bool is_training,
      ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoRnnBackwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const DeviceMemory<int>& seq_lengths_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<T>& output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<T>& output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<T>& output_c_data,
      const DeviceMemory<T>& output_backprop_data,
      const DeviceMemory<T>& output_h_backprop_data,
      const DeviceMemory<T>& output_c_backprop_data,
      DeviceMemory<T>* input_backprop_data,
      DeviceMemory<T>* input_h_backprop_data,
      DeviceMemory<T>* input_c_backprop_data,
      DeviceMemory<T>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  port::Status DoCtcLossImpl(
      Stream* stream, const CudnnRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const CudnnRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const CudnnCtcLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);

  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
      int* ctc_loss_algo_id) override;

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_