/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The ROCM-specific DNN library support, implementing the general DnnSupport
// interface.

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_

#include "absl/synchronization/mutex.h"
#include "rocm/include/miopen/miopen.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/temporary_device_memory.h"

namespace stream_executor {
namespace gpu {

class GpuExecutor;
class MIOpenRnnDescriptor;
class MIOpenRnnSequenceTensorDescriptor;
class MIOpenRnnStateTensorDescriptor;
class MIOpenCTCLossDescriptor;

// Opaque and unique identifier for the MIOpen plugin.
extern const PluginId kMIOpenPlugin;

struct PoolingWorkspaceDescriptor {
  std::vector<int64> input_dims;
  std::vector<int64> output_dims;
  dnn::PoolingDescriptor op;
  int dtype;
  uint64_t timestamp;
  std::unique_ptr<TemporaryDeviceMemory<uint8>> workspace;
  size_t workspace_size;
  bool IsSame(const dnn::BatchDescriptor& input_dimensions,
              const dnn::BatchDescriptor& output_dimensions,
              const dnn::PoolingDescriptor& pooling_dimensions, int _type);
};

struct PoolingWorkspaceCache {
  std::map<const void*, PoolingWorkspaceDescriptor> cache;
  const int trim_size = 1000;
  const uint64_t memory_budget = 2e7;
  uint64_t timestamp = 0;
  uint64_t memory_used = 0;
  bool find(const void* p, const dnn::BatchDescriptor& input_dimensions,
            const dnn::BatchDescriptor& output_dimensions,
            const dnn::PoolingDescriptor& pooling_dimensions, int _type,
            PoolingWorkspaceDescriptor*& pdesc);
  void insert(const void* p, const dnn::BatchDescriptor& input_dimensions,
              const dnn::BatchDescriptor& output_dimensions,
              const dnn::PoolingDescriptor& pooling_dimensions, int _type,
              std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace,
              size_t wsp_size, hipStream_t hip_stream);

 private:
  void trim(hipStream_t hip_stream);
};
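
// A sketch of how this cache is typically used by the pooling implementation
// (MIOpenSupport below keeps one instance as m_pooling_cache; the
// authoritative call sites live in rocm_dnn.cc, and the names `workspace_cache`,
// `input_ptr`, `miopen_dtype`, `wsp`, and `wsp_size` here are illustrative
// only):
//
//   PoolingWorkspaceDescriptor* pdesc = nullptr;
//   if (workspace_cache.find(input_ptr, input_dims, output_dims, pooling_dims,
//                            miopen_dtype, pdesc)) {
//     // Cache hit: reuse pdesc->workspace for the backward pass.
//   } else {
//     // Cache miss: allocate a workspace, run the forward pass, then publish
//     // the workspace. insert() is expected to evict stale entries (via the
//     // private trim()) to stay within trim_size / memory_budget.
//     workspace_cache.insert(input_ptr, input_dims, output_dims, pooling_dims,
//                            miopen_dtype, wsp, wsp_size, hip_stream);
//   }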

// miopen-library based DNN support. For details on overridden interface
// functions, see dnn.h.
class MIOpenSupport : public dnn::DnnSupport {
 public:
  explicit MIOpenSupport(GpuExecutor* parent);

  port::Status Init() override;
  port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

  port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
      int num_layers, int hidden_size, int input_size, int cell_size,
      int batch_size, dnn::RnnInputMode input_mode,
      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
      float dropout, uint64 seed, ScratchAllocator* state_allocator,
      bool use_padded_io) override;

  port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
  createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
                                    int data_size,
                                    dnn::DataType data_type) override;

  port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
  createRnnStateTensorDescriptor(int num_layer, int batch_size, int data_size,
                                 dnn::DataType data_type) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<Eigen::half>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<Eigen::half>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<Eigen::half>& input_c_data,
                    const DeviceMemory<Eigen::half>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<Eigen::half>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<Eigen::half>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<Eigen::half>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<float>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<float>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<float>& input_c_data,
                    const DeviceMemory<float>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<float>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<float>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<float>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnForward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                    const dnn::RnnSequenceTensorDescriptor& input_desc,
                    const DeviceMemory<double>& input_data,
                    const DeviceMemory<int>& seq_lengths_data,
                    const dnn::RnnStateTensorDescriptor& input_h_desc,
                    const DeviceMemory<double>& input_h_data,
                    const dnn::RnnStateTensorDescriptor& input_c_desc,
                    const DeviceMemory<double>& input_c_data,
                    const DeviceMemory<double>& params,
                    const dnn::RnnSequenceTensorDescriptor& output_desc,
                    DeviceMemory<double>* output_data,
                    const dnn::RnnStateTensorDescriptor& output_h_desc,
                    DeviceMemory<double>* output_h_data,
                    const dnn::RnnStateTensorDescriptor& output_c_desc,
                    DeviceMemory<double>* output_c_data, bool is_training,
                    ScratchAllocator* reserve_space_allocator,
                    ScratchAllocator* workspace_allocator,
                    dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<Eigen::half>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<Eigen::half>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<Eigen::half>& input_c_data,
                     const DeviceMemory<Eigen::half>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<Eigen::half>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<Eigen::half>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<Eigen::half>& output_c_data,
                     const DeviceMemory<Eigen::half>& output_backprop_data,
                     const DeviceMemory<Eigen::half>& output_h_backprop_data,
                     const DeviceMemory<Eigen::half>& output_c_backprop_data,
                     DeviceMemory<Eigen::half>* input_backprop_data,
                     DeviceMemory<Eigen::half>* input_h_backprop_data,
                     DeviceMemory<Eigen::half>* input_c_backprop_data,
                     DeviceMemory<Eigen::half>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<float>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<float>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<float>& input_c_data,
                     const DeviceMemory<float>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<float>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<float>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<float>& output_c_data,
                     const DeviceMemory<float>& output_backprop_data,
                     const DeviceMemory<float>& output_h_backprop_data,
                     const DeviceMemory<float>& output_c_backprop_data,
                     DeviceMemory<float>* input_backprop_data,
                     DeviceMemory<float>* input_h_backprop_data,
                     DeviceMemory<float>* input_c_backprop_data,
                     DeviceMemory<float>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool DoRnnBackward(Stream* stream, const dnn::RnnDescriptor& rnn_desc,
                     const dnn::RnnSequenceTensorDescriptor& input_desc,
                     const DeviceMemory<double>& input_data,
                     const DeviceMemory<int>& seq_lengths_data,
                     const dnn::RnnStateTensorDescriptor& input_h_desc,
                     const DeviceMemory<double>& input_h_data,
                     const dnn::RnnStateTensorDescriptor& input_c_desc,
                     const DeviceMemory<double>& input_c_data,
                     const DeviceMemory<double>& params,
                     const dnn::RnnSequenceTensorDescriptor& output_desc,
                     const DeviceMemory<double>& output_data,
                     const dnn::RnnStateTensorDescriptor& output_h_desc,
                     const DeviceMemory<double>& output_h_data,
                     const dnn::RnnStateTensorDescriptor& output_c_desc,
                     const DeviceMemory<double>& output_c_data,
                     const DeviceMemory<double>& output_backprop_data,
                     const DeviceMemory<double>& output_h_backprop_data,
                     const DeviceMemory<double>& output_c_backprop_data,
                     DeviceMemory<double>* input_backprop_data,
                     DeviceMemory<double>* input_h_backprop_data,
                     DeviceMemory<double>* input_c_backprop_data,
                     DeviceMemory<double>* params_backprop_data,
                     DeviceMemory<uint8>* reserve_space_data,
                     ScratchAllocator* workspace_allocator,
                     dnn::ProfileResult* output_profile_result) override;

  bool GetConvolveAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardDataAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetConvolveBackwardFilterAlgorithms(
      CudaComputeCapability cuda_compute_capability,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<float>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationForward(
      Stream* stream, const DeviceMemory<Eigen::half>& x,
      const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
      const DeviceMemory<float>& estimated_mean,
      const DeviceMemory<float>& estimated_variance,
      const DeviceMemory<Eigen::half>& side_input,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
      DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
      DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
      bool is_training, ScratchAllocator* reserve_space_allocator,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<float>& y_backprop,
      const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& variance,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
      DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  bool DoBatchNormalizationBackward(
      Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
      const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
      const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
      const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<Eigen::half>* x_backprop,
      DeviceMemory<float>* scale_backprop, DeviceMemory<float>* offset_backprop,
      DeviceMemory<uint8>* reserve_space_data,
      ScratchAllocator* workspace_allocator) override;

  port::Status DoConvolve(
      dnn::ConvolutionKind kind, dnn::DataType element_type,
      dnn::DataType output_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
      dnn::ProfileResult* output_profile_result) override;

  port::Status DoFusedConvolve(
      Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
      dnn::DataType bias_type, dnn::DataType output_type,
      const dnn::BatchDescriptor& conv_input_descriptor,
      DeviceMemoryBase conv_input_data, double conv_input_scale,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      DeviceMemoryBase side_input_data, double side_input_scale,
      const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
      dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data, ScratchAllocator* scratch_allocator,
      const dnn::AlgorithmConfig& algorithm_config,
      dnn::ProfileResult* output_profile_result) override;

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int8>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by MIOpen";
    return false;
  }

  bool DoConvolveQuantized(
      Stream* stream, const dnn::BatchDescriptor& input_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<int16>& filter_coefficients,
      const DeviceMemory<float>& coefficient_scales,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DoConvolveQuantized not supported by MIOpen";
    return false;
  }

  bool DoSeparableConvolve(
      Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
      const DeviceMemory<float>& input_data,
      const dnn::FilterDescriptor& filter_descriptor, int depth_multiplier,
      const DeviceMemory<float>& first_weights,
      const DeviceMemory<float>& second_weights,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "separable convolution not supported by MIOpen";
    return false;
  }

  bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
                const DeviceMemory<float>& weights,
                const dnn::BatchDescriptor& input_dimensions,
                const dnn::BatchDescriptor& output_dimensions,
                DeviceMemory<float>* output_data) override;

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int8>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by MIOpen";
    return false;
  }

  bool DoMatMulQuantized(Stream* stream, const DeviceMemory<float>& input_data,
                         const DeviceMemory<int16>& quantized_weights,
                         const DeviceMemory<float>& weight_scales,
                         const dnn::BatchDescriptor& input_dimensions,
                         const dnn::BatchDescriptor& output_dimensions,
                         DeviceMemory<float>* output_data) override {
    LOG(ERROR) << "DNN MatMulQuantized not supported by MIOpen";
    return false;
  }

  bool DoBiasAdd(Stream* stream, const DeviceMemory<float>& input_data,
                 const DeviceMemory<float>& biases,
                 const dnn::BatchDescriptor& dimensions,
                 DeviceMemory<float>* output_data) override;

  bool DoActivate(Stream* stream, dnn::ActivationMode activation_mode,
                  const dnn::BatchDescriptor& dimensions,
                  const DeviceMemory<float>& input_data,
                  DeviceMemory<float>* output_data, uint64 options) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<double>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<double>* output_data,
                     ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<float>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<float>* output_data,
                     ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolForward(Stream* stream,
                     const dnn::PoolingDescriptor& pooling_dimensions,
                     const dnn::BatchDescriptor& input_dimensions,
                     const DeviceMemory<Eigen::half>& input_data,
                     const dnn::BatchDescriptor& output_dimensions,
                     DeviceMemory<Eigen::half>* output_data,
                     ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<double>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<double>& output_data,
                      const DeviceMemory<double>& input_diff_data,
                      DeviceMemory<double>* output_diff_data,
                      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<float>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<float>& output_data,
                      const DeviceMemory<float>& input_diff_data,
                      DeviceMemory<float>* output_diff_data,
                      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoPoolBackward(Stream* stream,
                      const dnn::PoolingDescriptor& pooling_dimensions,
                      const dnn::BatchDescriptor& input_dimensions,
                      const DeviceMemory<Eigen::half>& input_data,
                      const dnn::BatchDescriptor& output_dimensions,
                      const DeviceMemory<Eigen::half>& output_data,
                      const DeviceMemory<Eigen::half>& input_diff_data,
                      DeviceMemory<Eigen::half>* output_diff_data,
                      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoNormalizeWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& input_data,
      DeviceMemory<float>* output_data) override;

  bool DoNormalizeBackwardWithDimensions(
      Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
      const dnn::BatchDescriptor& dimensions,
      const DeviceMemory<float>& raw_data,
      const DeviceMemory<float>& normalized_data,
      const DeviceMemory<float>& normalized_variable_gradient,
      DeviceMemory<float>* raw_variable_gradient,
      ScratchAllocator* workspace_allocator = nullptr) override;

  bool DoDepthConcatenate(
      Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      DeviceMemory<float>* output_data) override;

  bool DoElementwiseOperate(
      Stream* stream, dnn::ElementwiseOperation operation,
      port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
      port::ArraySlice<const DeviceMemory<float>*> input_data,
      const dnn::BatchDescriptor& output_dimensions,
      DeviceMemory<float>* output_data) override;

  bool DoXYPad(Stream* stream, const dnn::BatchDescriptor& dimensions,
               const DeviceMemory<float>& input_data, int64 left_pad,
               int64 right_pad, int64 top_pad, int64 bottom_pad,
               DeviceMemory<float>* output_data) override;

  bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor& dimensions,
                 const DeviceMemory<float>& input_data, int64 left_trim,
                 int64 right_trim, int64 top_trim, int64 bottom_trim,
                 DeviceMemory<float>* output_data) override;

  bool DoMemcpyD2HQuantized(Stream* stream,
                            const DeviceMemory<float>& device_unquantized_src,
                            dnn::QuantizedActivationMode mode, void* host_dst,
                            int64 size) override;

  bool DoMemcpyH2DQuantized(
      Stream* stream, const void* host_src, int64 size,
      dnn::QuantizedActivationMode mode,
      DeviceMemory<float>* device_unquantized_dst) override;

  // Derives an output batch descriptor from an input batch and convolution
  // descriptors.
  bool DeriveOutputBatchDescriptor(
      const dnn::BatchDescriptor& batch_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      dnn::BatchDescriptor* output_batch_descriptor);
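
  // A sketch of the intended use: deriving a convolution's output shape from
  // its input, filter, and convolution descriptors. `input_desc`,
  // `filter_desc`, and `conv_desc` below are illustrative names for fully
  // populated descriptors.
  //
  //   dnn::BatchDescriptor output_desc;
  //   if (!DeriveOutputBatchDescriptor(input_desc, filter_desc, conv_desc,
  //                                    &output_desc)) {
  //     return false;  // Derivation failed.
  //   }
  //   // output_desc now describes the convolution output tensor.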

  bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc,
                         dnn::DataType input_type,
                         const DeviceMemoryBase& input_data,
                         const dnn::BatchDescriptor& output_desc,
                         dnn::DataType output_type, float scale,
                         DeviceMemoryBase* output_data) override;

  bool DoFusedConvolutionBiasActivation(
      Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<float>& conv_input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<float>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<float>* output_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationInference(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<float>& x_data,
      const dnn::BatchDescriptor& scale_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& mean_data,
      const DeviceMemory<float>& variance_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationInference(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<Eigen::half>& x_data,
      const dnn::BatchDescriptor& scale_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& mean_data,
      const DeviceMemory<float>& variance_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationForward(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<float>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
      DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
      DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationForward(
      Stream* stream, const dnn::BatchDescriptor& x_descriptor,
      const DeviceMemory<Eigen::half>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data, double epsilon,
      dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
      DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
      DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationBackward(
      Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<float>& y_act_backprop_data,
      const DeviceMemory<float>& y_act_data,
      dnn::ActivationMode activation_mode, const DeviceMemory<float>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& saved_mean_data,
      const DeviceMemory<float>& saved_var_data,
      DeviceMemory<float>* x_bn_backprop_data,
      DeviceMemory<float>* scale_backprop_data,
      DeviceMemory<float>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result) override;

  bool DoFusedBatchNormActivationBackward(
      Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<Eigen::half>& y_act_backprop_data,
      const DeviceMemory<Eigen::half>& y_act_data,
      dnn::ActivationMode activation_mode,
      const DeviceMemory<Eigen::half>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<float>& scale_data,
      const DeviceMemory<float>& offset_data,
      const DeviceMemory<float>& saved_mean_data,
      const DeviceMemory<float>& saved_var_data,
      DeviceMemory<Eigen::half>* x_bn_backprop_data,
      DeviceMemory<float>* scale_backprop_data,
      DeviceMemory<float>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result) override;

  GpuExecutor* GetParentExecutor() { return parent_; }

  port::Status DoCtcLoss(Stream* stream, dnn::DataType element_type,
                         const dnn::RnnStateTensorDescriptor& probs_desc,
                         const DeviceMemoryBase probs_data,
                         absl::Span<const int> labels_data,
                         absl::Span<const int> labels_lengths_data,
                         absl::Span<const int> input_lengths_data,
                         DeviceMemoryBase costs_data,
                         const dnn::RnnStateTensorDescriptor& grads_desc,
                         DeviceMemoryBase grads_data,
                         DeviceMemory<uint8> scratch_memory,
                         int ctc_loss_algo_id) override;

 private:
  GpuExecutor* parent_;  // Parent executor object. Not owned.

  // Flag to indicate whether Get*Algorithm routines should return only the
  // best algorithm (as opposed to a list of all applicable ones).
  bool return_best_algo_only_;

  // Flag to indicate whether to use Immediate (or Find) mode for convolutions.
  bool use_immediate_mode_;

  // Provide access to the MIOpen handle.
  std::unique_ptr<class MIOpenAccess> miopen_;

  PoolingWorkspaceCache m_pooling_cache;
  bool m_pooling_cache_allowed = false;
  bool m_pooling_cache_enabled = false;

  template <class T, class U>
  bool DoBatchNormalizationForwardImpl(
      Stream* stream, dnn::DataType input_data_type,
      dnn::DataType scale_data_type, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
      const DeviceMemory<U>& estimated_mean,
      const DeviceMemory<U>& estimated_variance,
      const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      const double exponential_average_factor,
      dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
      DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
      DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
      bool is_training);

  template <class T, class U>
  bool DoBatchNormalizationBackwardImpl(
      Stream* stream, int miopen_input_type, int miopen_scale_type,
      const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
      const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
      const DeviceMemory<U>& variance, const dnn::BatchDescriptor& x_desc,
      const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
      DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
      DeviceMemory<U>* offset_backprop);

  template <class T>
  bool DoRnnForwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
                        const MIOpenRnnSequenceTensorDescriptor& input_desc,
                        const DeviceMemory<T>& input_data,
                        const MIOpenRnnStateTensorDescriptor& input_h_desc,
                        const DeviceMemory<T>& input_h_data,
                        const MIOpenRnnStateTensorDescriptor& input_c_desc,
                        const DeviceMemory<T>& input_c_data,
                        const DeviceMemory<T>& params,
                        const MIOpenRnnSequenceTensorDescriptor& output_desc,
                        DeviceMemory<T>* output_data,
                        const MIOpenRnnStateTensorDescriptor& output_h_desc,
                        DeviceMemory<T>* output_h_data,
                        const MIOpenRnnStateTensorDescriptor& output_c_desc,
                        DeviceMemory<T>* output_c_data, bool is_training,
                        ScratchAllocator* reserve_space_allocator,
                        ScratchAllocator* workspace_allocator);
  template <class T>
  bool DoRnnBackwardImpl(Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
                         const MIOpenRnnSequenceTensorDescriptor& input_desc,
                         const DeviceMemory<T>& input_data,
                         const MIOpenRnnStateTensorDescriptor& input_h_desc,
                         const DeviceMemory<T>& input_h_data,
                         const MIOpenRnnStateTensorDescriptor& input_c_desc,
                         const DeviceMemory<T>& input_c_data,
                         const DeviceMemory<T>& params,
                         const MIOpenRnnSequenceTensorDescriptor& output_desc,
                         const DeviceMemory<T>& output_data,
                         const MIOpenRnnStateTensorDescriptor& output_h_desc,
                         const DeviceMemory<T>& output_h_data,
                         const MIOpenRnnStateTensorDescriptor& output_c_desc,
                         const DeviceMemory<T>& output_c_data,
                         const DeviceMemory<T>& output_backprop_data,
                         const DeviceMemory<T>& output_h_backprop_data,
                         const DeviceMemory<T>& output_c_backprop_data,
                         DeviceMemory<T>* input_backprop_data,
                         DeviceMemory<T>* input_h_backprop_data,
                         DeviceMemory<T>* input_c_backprop_data,
                         DeviceMemory<T>* params_backprop_data,
                         DeviceMemory<uint8>* reserve_space_data,
                         ScratchAllocator* workspace_allocator);

  template <typename T>
  bool DoFusedConvolutionBiasActivationImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& conv_input_descriptor,
      const DeviceMemory<T>& conv_input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      const DeviceMemory<T>& filter_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& bias_descriptor,
      const DeviceMemory<T>& bias_data, dnn::ActivationMode activation_mode,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemory<T>* output_data, dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationInferenceImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      const DeviceMemory<U>& mean_data, const DeviceMemory<U>& variance_data,
      double epsilon, dnn::ActivationMode activation_mode,
      DeviceMemory<T>* y_data, dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationForwardImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      double epsilon, dnn::ActivationMode activation_mode,
      DeviceMemory<T>* y_data, DeviceMemory<U>* batch_mean_data,
      DeviceMemory<U>* batch_var_data, DeviceMemory<U>* saved_mean_data,
      DeviceMemory<U>* saved_var_data,
      dnn::ProfileResult* output_profile_result);

  template <typename T, typename U>
  bool DoFusedBatchNormActivationBackwardImpl(
      Stream* stream,
      int miopen_type,  // Actually miopenDataType_t.
      const dnn::BatchDescriptor& y_act_backprop_descriptor,
      const DeviceMemory<T>& y_act_backprop_data,
      const DeviceMemory<T>& y_act_data, dnn::ActivationMode activation_mode,
      const DeviceMemory<T>& x_bn_data,
      const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
      const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
      const DeviceMemory<U>& saved_mean_data,
      const DeviceMemory<U>& saved_var_data,
      DeviceMemory<T>* x_bn_backprop_data, DeviceMemory<U>* scale_backprop_data,
      DeviceMemory<U>* offset_backprop_data,
      dnn::ProfileResult* output_profile_result);

  port::Status DoPrepareForConvolution(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::AlgorithmConfig& algorithm_config,
      ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
      DeviceMemory<uint8>* scratch_memory) override;

  port::Status DoCtcLossImpl(
      Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
      const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
      const MIOpenRnnStateTensorDescriptor& grads_desc,
      DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
      DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id);

  port::Status DoPrepareForCtcLoss(
      Stream* stream, dnn::DataType element_type,
      const dnn::RnnStateTensorDescriptor& probs_desc,
      const dnn::RnnStateTensorDescriptor& grads_desc,
      absl::Span<const int> labels_data,
      absl::Span<const int> labels_lengths_data,
      absl::Span<const int> input_lengths_data,
      ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
      int* ctc_loss_algo_id) override;

  bool GetMIOpenConvolveAlgorithmsImmediateMode(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  bool GetMIOpenConvolveAlgorithmsFindMode(
      dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
      const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
      const dnn::FilterDescriptor& filter_descriptor,
      DeviceMemoryBase filter_data,
      const dnn::BatchDescriptor& output_descriptor,
      DeviceMemoryBase output_data,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      ScratchAllocator* scratch_allocator,
      std::vector<dnn::ProfileResult>* out_algorithms);

  template <class T>
  bool DoPoolBackwardImpl(Stream* stream,
                          const dnn::PoolingDescriptor& pooling_dimensions,
                          const dnn::BatchDescriptor& input_dimensions,
                          const DeviceMemory<T>& input_data,
                          const dnn::BatchDescriptor& output_dimensions,
                          const DeviceMemory<T>& output_data,
                          const DeviceMemory<T>& input_diff_data,
                          DeviceMemory<T>* output_diff_data,
                          ScratchAllocator* workspace_allocator = nullptr);

  SE_DISALLOW_COPY_AND_ASSIGN(MIOpenSupport);
};
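
// Note: client code does not normally name MIOpenSupport directly; it reaches
// this class through the generic dnn::DnnSupport interface. A minimal sketch,
// assuming `stream_exec` is a StreamExecutor backed by the ROCm platform:
//
//   dnn::DnnSupport* dnn = stream_exec->AsDnn();
//   if (dnn != nullptr) {
//     auto version = dnn->GetVersion();  // MIOpen version, if available.
//   }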

}  // namespace gpu
}  // namespace stream_executor

#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DNN_H_