/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The CUDA-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_

#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class CudnnRnnDescriptor;
class CudnnRnnSequenceTensorDescriptor;
class CudnnRnnStateTensorDescriptor;

// Opaque and unique identifier for the cuDNN plugin.
extern const PluginId kCuDnnPlugin;
// cuDNN-library-based DNN support. For details on the overridden interface
// functions, see dnn.h.
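//
// Usage sketch (illustrative only, not prescribed by this header): callers
// normally do not construct CudnnSupport directly; it is registered under
// kCuDnnPlugin and reached through the owning executor. Assuming an
// already-initialized StreamExecutor* executor on a CUDA device:
//
//   if (dnn::DnnSupport* dnn = executor->AsDnn()) {
//     auto version = dnn->GetVersion();  // Loaded cuDNN version, if any.
//   }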
class CudnnSupport : public dnn::DnnSupport {
 public:
  explicit CudnnSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int batch_size,
      dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
      dnn::RnnMode rnn_mode, dnn::DataType data_type,
      const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
      ScratchAllocator* state_allocator) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
                                    int data_size,
                                    const absl::Span<const int>& seq_lengths,
                                    bool time_major,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;
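
  // The DoRnnForward/DoRnnBackward overloads below consume the descriptors
  // produced by the factory methods above. Typical (illustrative) training
  // flow: create the RNN, sequence, and state descriptors; call DoRnnForward
  // with is_training set so that reserve_space_allocator captures the reserve
  // space; then hand that same reserve space to DoRnnBackward.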

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool GetConvolveAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<float>* y, DeviceMemory<float>* batch_mean,
      DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
      DeviceMemory<float>* saved_inv_var, bool is_training,
      std::function<const DeviceMemory<float>&()> var_to_inv_var,
      std::function<void()> inv_var_to_var) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<Eigen::half>* y, DeviceMemory<float>* batch_mean,
      DeviceMemory<float>* batch_var, DeviceMemory<float>* saved_mean,
      DeviceMemory<float>* saved_inv_var, bool is_training,
      std::function<const DeviceMemory<float>&()> var_to_inv_var,
      std::function<void()> inv_var_to_var) override;
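
  // In training mode the forward overloads above produce batch statistics
  // (batch_mean/batch_var) and cache saved_mean/saved_inv_var; the backward
  // overloads below take that cached mean and *inverse* variance. The
  // var_to_inv_var/inv_var_to_var callbacks let callers that store plain
  // variance convert between the two representations (an inference from the
  // parameter names and types; see dnn.h for the authoritative contract).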

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
      DeviceMemory<float>* offset_backprop) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop,
      DeviceMemory<float>* offset_backprop) override;

  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;
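
  // DoConvolve runs a convolution whose algorithm and scratch buffer have
  // already been chosen; see DoPrepareForConvolution near the bottom of this
  // class, which resolves an AlgorithmConfig into the algorithm_desc and
  // scratch_memory consumed here (a prepare/execute split).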

  bool DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<double>& conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<double>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<double>& side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<double>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<float>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<float>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<float>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedConvolve(Stream* stream,
                       const dnn::BatchDescriptor& conv_input_descriptor,
                       const DeviceMemory<Eigen::half>& conv_input_data,
                       float conv_input_scale,
                       const dnn::FilterDescriptor& filter_descriptor,
                       const DeviceMemory<Eigen::half>& filter_data,
                       const dnn::ConvolutionDescriptor& convolution_descriptor,
                       const DeviceMemory<Eigen::half>& side_input_data,
                       float side_input_scale,
                       const dnn::BatchDescriptor& bias_descriptor,
                       const DeviceMemory<Eigen::half>& biases,
                       dnn::ActivationMode activation_mode,
                       const dnn::BatchDescriptor& output_descriptor,
                       DeviceMemory<Eigen::half>* output_data,
                       ScratchAllocator* scratch_allocator,
                       const dnn::AlgorithmConfig& algorithm_config,
                       dnn::ProfileResult* output_profile_result) override;

  bool DoFusedConvolve(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<int8>& conv_input_data, float conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<int8>& side_input_data, float side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<int8>* output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;
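
  // For the DoFusedConvolve overloads above, the fused computation is,
  // roughly (see dnn.h for the authoritative description):
  //
  //   output = activation(conv_input_scale * conv(conv_input, filter)
  //                       + side_input_scale * side_input + bias)
  //
  // i.e. convolution, scaled residual add, bias add, and activation in a
  // single call.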

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by cuDNN";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by cuDNN";
    return false;
  }

  bool DoConvolveBackwardBias(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<double>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<double>* backward_bias_data) override;

  bool DoConvolveBackwardBias(Stream* stream,
                              const dnn::BatchDescriptor& input_descriptor,
                              const DeviceMemory<float>& input_data,
                              const dnn::BatchDescriptor& bias_descriptor,
                              DeviceMemory<float>* backward_bias_data) override;

  bool DoConvolveBackwardBias(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<Eigen::half>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<Eigen::half>* backward_bias_data) override;

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by cuDNN";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by cuDNN";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64 options) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<double>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<double>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<float>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<float>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<Eigen::half>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<int8>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<int8>* output_data,
                     ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<double>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<double>& output_data,
                      const DeviceMemory<double>& input_diff_data,
                      DeviceMemory<double>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<float>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<float>& output_data,
                      const DeviceMemory<float>& input_diff_data,
                      DeviceMemory<float>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<Eigen::half>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<Eigen::half>& output_data,
                      const DeviceMemory<Eigen::half>& input_diff_data,
                      DeviceMemory<Eigen::half>* output_diff_data,
                      ScratchAllocator* workspace_allocator) override;

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator) override;

  bool DoDepthConcatenate(
      Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(
      Stream* stream, dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      const dnn::BatchDescriptor& output_dimensions,
      DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions,
               const DeviceMemory<float> &input_data,
               int64 left_pad, int64 right_pad, int64 top_pad,
               int64 bottom_pad, DeviceMemory<float> *output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions,
                 const DeviceMemory<float> &input_data,
                 int64 left_trim, int64 right_trim, int64 top_trim,
                 int64 bottom_trim, DeviceMemory<float> *output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64 size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64 size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);
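
  // Illustrative shape rule (assuming unit dilation): along a spatial
  // dimension with input size I, filter size F, per-side padding P, and
  // stride S, the derived output size is (I + 2 * P - F) / S + 1; e.g. a
  // 224x224 input with a 7x7 filter, padding 3, and stride 2 yields 112x112.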

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Provides access to the cuDNN handle.
  std::unique_ptr<class CudnnAccess> cudnn_;

  template <class T, class U>
  port::Status DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<T>* y, DeviceMemory<U>* batch_mean,
      DeviceMemory<U>* batch_var, DeviceMemory<U>* saved_mean,
      DeviceMemory<U>* saved_inv_var, bool is_training,
      std::function<const DeviceMemory<U>&()> var_to_inv_var,
      std::function<void()> inv_var_to_var);

  template <class T, class U>
  port::Status DoBatchNormalizationBackwardImpl(
      Stream* stream, int cudnn_input_type, int cudnn_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
      const DeviceMemory<U>& inv_var, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
      DeviceMemory<U>* offset_backprop);

  template <typename ElementType, typename BiasType, typename ScaleType>
  port::Status DoFusedConvolveImpl(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<ElementType>& conv_input_data,
      ScaleType conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<ElementType>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const DeviceMemory<ElementType>& side_input_data,
      ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<ElementType>* output_data, dnn::DataType accumulator_type,
      ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoConvolveBackwardBiasImpl(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<T>& input_data,
      const dnn::BatchDescriptor& bias_descriptor,
      DeviceMemory<T>* backward_bias_data);

  template <class T>
  port::Status DoRnnForwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      DeviceMemory<T>* output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      DeviceMemory<T>* output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      DeviceMemory<T>* output_c_data, bool is_training,
      ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

  template <class T>
  port::Status DoRnnBackwardImpl(
      Stream* stream, const CudnnRnnDescriptor& rnn_desc,
      const CudnnRnnSequenceTensorDescriptor& input_desc,
      const DeviceMemory<T>& input_data,
      const CudnnRnnStateTensorDescriptor& input_h_desc,
      const DeviceMemory<T>& input_h_data,
      const CudnnRnnStateTensorDescriptor& input_c_desc,
      const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
      const CudnnRnnSequenceTensorDescriptor& output_desc,
      const DeviceMemory<T>& output_data,
      const CudnnRnnStateTensorDescriptor& output_h_desc,
      const DeviceMemory<T>& output_h_data,
      const CudnnRnnStateTensorDescriptor& output_c_desc,
      const DeviceMemory<T>& output_c_data,
      const DeviceMemory<T>& output_backprop_data,
      const DeviceMemory<T>& output_h_backprop_data,
      const DeviceMemory<T>& output_c_backprop_data,
      DeviceMemory<T>* input_backprop_data,
      DeviceMemory<T>* input_h_backprop_data,
      DeviceMemory<T>* input_c_backprop_data,
      DeviceMemory<T>* params_backprop_data,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator,
      dnn::ProfileResult* output_profile_result);

 private:
  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_