1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/rocm/rocm_dnn.h"
17 
18 #include <functional>
19 #include <memory>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "third_party/eigen3/Eigen/Core"
24 #include "rocm/include/miopen/miopen.h"
25 #include "tensorflow/core/lib/hash/hash.h"
26 #include "tensorflow/core/util/env_var.h"
27 #include "tensorflow/stream_executor/dnn.h"
28 #include "tensorflow/stream_executor/gpu/gpu_activation.h"
29 #include "tensorflow/stream_executor/gpu/gpu_driver.h"
30 #include "tensorflow/stream_executor/gpu/gpu_executor.h"
31 #include "tensorflow/stream_executor/gpu/gpu_stream.h"
32 #include "tensorflow/stream_executor/gpu/gpu_timer.h"
33 #include "tensorflow/stream_executor/lib/env.h"
34 #include "tensorflow/stream_executor/lib/error.h"
35 #include "tensorflow/stream_executor/lib/initialize.h"
36 #include "tensorflow/stream_executor/lib/threadpool.h"
37 #include "tensorflow/stream_executor/platform/dso_loader.h"
38 #include "tensorflow/stream_executor/platform/logging.h"
39 #include "tensorflow/stream_executor/plugin_registry.h"
40 #include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
41 #include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
42 #include "tensorflow/stream_executor/scratch_allocator.h"
43 #include "tensorflow/stream_executor/stream.h"
44 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
45 
46 namespace {
47 
48 // Converts (via narrowing) a WideT value to a NarrowT value, and checks that
49 // the value is unchanged by the conversion.
50 template <typename WideT, typename NarrowT>
51 NarrowT CheckedNarrowing(const WideT& wide) {
52   NarrowT narrow = wide;
53   CHECK_EQ(narrow, wide)
54       << "checked narrowing failed; values not equal post-conversion";
55   return narrow;
56 }
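// Illustrative usage of the helper above (values are hypothetical):
//   CheckedNarrowing<int64, int>(7) returns 7, whereas a value that does not
//   fit the narrow type, e.g. CheckedNarrowing<int64, int>(int64{1} << 40),
//   fails the CHECK_EQ and aborts.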
57 
58 const int kConvDebugVlogLevel = 3;
59 
60 }  // namespace
61 
62 namespace stream_executor {
63 
64 using dnn::AlgorithmDesc;
65 using dnn::BatchDescriptor;
66 using dnn::ConvolutionDescriptor;
67 using dnn::FilterDescriptor;
68 using dnn::NormalizeDescriptor;
69 using dnn::PoolingDescriptor;
70 
71 namespace gpu {
72 
73 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kMIOpenPlugin);
74 
75 string ToString(miopenStatus_t status) {
76   switch (status) {
77     case miopenStatusSuccess:
78       return "miopenStatusSuccess";
79     case miopenStatusNotInitialized:
80       return "miopenStatusNotInitialized";
81     case miopenStatusAllocFailed:
82       return "miopenStatusAllocFailed";
83     case miopenStatusBadParm:
84       return "miopenStatusBadParm";
85     case miopenStatusInternalError:
86       return "miopenStatusInternalError";
87     case miopenStatusInvalidValue:
88       return "miopenStatusInvalidValue";
89     case miopenStatusNotImplemented:
90       return "miopenStatusNotImplemented";
91     case miopenStatusUnknownError:
92       return "miopenStatusUnknownError";
93     default:
94       return absl::StrCat("<unknown miopen status: ", static_cast<int>(status),
95                           ">");
96   }
97 }
98 
99 string ToString(miopenConvFwdAlgorithm_t algorithm) {
100   string s;
101   switch (algorithm) {
102     case miopenConvolutionFwdAlgoGEMM:
103       s = "GEMM";
104       break;
105     case miopenConvolutionFwdAlgoDirect:
106       s = "Direct";
107       break;
108     case miopenConvolutionFwdAlgoFFT:
109       s = "FFT";
110       break;
111     case miopenConvolutionFwdAlgoWinograd:
112       s = "Winograd";
113       break;
114     case miopenConvolutionFwdAlgoImplicitGEMM:
115       s = "Implicit GEMM";
116       break;
117   }
118   return s;
119 }
120 
121 string ToString(miopenConvBwdWeightsAlgorithm_t algorithm) {
122   string s;
123   switch (algorithm) {
124     case miopenConvolutionBwdWeightsAlgoGEMM:
125       s = "GEMM";
126       break;
127     case miopenConvolutionBwdWeightsAlgoDirect:
128       s = "Direct";
129       break;
130     case miopenConvolutionBwdWeightsAlgoWinograd:
131       s = "Winograd";
132       break;
133     case miopenConvolutionBwdWeightsAlgoImplicitGEMM:
134       s = "Implicit GEMM";
135       break;
136   }
137   return s;
138 }
139 
140 string ToString(miopenConvBwdDataAlgorithm_t algorithm) {
141   string s;
142   switch (algorithm) {
143     case miopenConvolutionBwdDataAlgoGEMM:
144       s = "GEMM";
145       break;
146     case miopenConvolutionBwdDataAlgoDirect:
147       s = "Direct";
148       break;
149     case miopenConvolutionBwdDataAlgoFFT:
150       s = "FFT";
151       break;
152     case miopenConvolutionBwdDataAlgoWinograd:
153       s = "Winograd";
154       break;
155     case miopenTransposeBwdDataAlgoGEMM:
156       s = "Transpose GEMM";
157       break;
158     case miopenConvolutionBwdDataAlgoImplicitGEMM:
159       s = "Implicit GEMM";
160       break;
161   }
162   return s;
163 }
164 
165 string ToString(miopenConvAlgorithm_t algorithm) {
166   string s;
167   switch (algorithm) {
168     case miopenConvolutionAlgoGEMM:
169       s = "GEMM";
170       break;
171     case miopenConvolutionAlgoDirect:
172       s = "Direct";
173       break;
174     case miopenConvolutionAlgoFFT:
175       s = "FFT";
176       break;
177     case miopenConvolutionAlgoWinograd:
178       s = "Winograd";
179       break;
180     case miopenConvolutionAlgoImplicitGEMM:
181       s = "Implicit GEMM";
182       break;
183   }
184   return s;
185 }
186 
187 // RAII wrapper for all calls to MIOpen with a MIOpen handle argument.
188 //
189 // See MIOpenAccess::GetHandle() for details.
190 class MIOpenHandle {
191  public:
192   // Takes ownership of the executor context and of the lock used to access
193   // MIOpen via the handle.
194   MIOpenHandle(gpu::ScopedActivateExecutorContext context,
195                std::unique_ptr<absl::MutexLock> lock, miopenHandle_t handle)
196       : context_(std::move(context)), lock_(std::move(lock)), handle_(handle) {}
197 
198   // Returns the MIOpen handle. Pass it directly to MIOpen APIs; don't keep a
199   // copy.
200   miopenHandle_t handle() const { return handle_; }
201 
202  private:
203   gpu::ScopedActivateExecutorContext context_;
204   std::unique_ptr<absl::MutexLock> lock_;
205   miopenHandle_t handle_;  // Not owned.
206 };
207 
208 namespace wrap {
209 
210 #ifdef PLATFORM_GOOGLE
211 #define STREAM_EXECUTOR_MIOPEN_WRAP(__name)      \
212   struct WrapperShim__##__name {                 \
213     template <typename... Args>                  \
214     miopenStatus_t operator()(Args... args) {    \
215       miopenStatus_t retval = ::__name(args...); \
216       return retval;                             \
217     }                                            \
218   } __name;
219 
220 #else
221 
222 #define STREAM_EXECUTOR_MIOPEN_WRAP(__name)                               \
223   struct DynLoadShim__##__name {                                          \
224     static const char* kName;                                             \
225     using FuncPtrT = std::add_pointer<decltype(::__name)>::type;          \
226     static void* GetDsoHandle() {                                         \
227       auto s = internal::CachedDsoLoader::GetMiopenDsoHandle();           \
228       return s.ValueOrDie();                                              \
229     }                                                                     \
230     static FuncPtrT LoadOrDie() {                                         \
231       void* f;                                                            \
232       auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
233                                                           kName, &f);     \
234       CHECK(s.ok()) << "could not find " << kName                         \
235                     << " in miopen DSO; dlerror: " << s.error_message();  \
236       return reinterpret_cast<FuncPtrT>(f);                               \
237     }                                                                     \
238     static FuncPtrT DynLoad() {                                           \
239       static FuncPtrT f = LoadOrDie();                                    \
240       return f;                                                           \
241     }                                                                     \
242     template <typename... Args>                                           \
243     miopenStatus_t operator()(Args... args) {                             \
244       return DynLoad()(args...);                                          \
245     }                                                                     \
246   } __name;                                                               \
247   const char* DynLoadShim__##__name::kName = #__name;
248 
249 #endif
250 
251 // clang-format off
252 #define MIOPEN_DNN_ROUTINE_EACH(__macro)                             \
253   __macro(miopenBatchNormalizationBackward)                          \
254   __macro(miopenBatchNormalizationForwardInference)                  \
255   __macro(miopenBatchNormalizationForwardTraining)                   \
256   __macro(miopenGetConvolutionForwardOutputDim)                      \
257   __macro(miopenGetConvolutionNdForwardOutputDim)                    \
258   __macro(miopenFindConvolutionForwardAlgorithm)                     \
259   __macro(miopenCreateTensorDescriptor)                              \
260   __macro(miopenDestroyTensorDescriptor)                             \
261   __macro(miopenSetNdPoolingDescriptor)                              \
262   __macro(miopenSetPoolingIndexType)                                 \
263   __macro(miopenSetLRNDescriptor)                                    \
264   __macro(miopenLRNGetWorkSpaceSize)                                 \
265   __macro(miopenCreateConvolutionDescriptor)                         \
266   __macro(miopenCreatePoolingDescriptor)                             \
267   __macro(miopenDestroyPoolingDescriptor)                            \
268   __macro(miopenCreateLRNDescriptor)                                 \
269   __macro(miopenDestroyLRNDescriptor)                                \
270   __macro(miopenDestroyConvolutionDescriptor)                        \
271   __macro(miopenCreateWithStream)                                    \
272   __macro(miopenDestroy)                                             \
273   __macro(miopenSetStream)                                           \
274   __macro(miopenSetAllocator)                                        \
275   __macro(miopenActivationForward)                                   \
276   __macro(miopenConvolutionForward)                                  \
277   __macro(miopenConvolutionBackwardBias)                             \
278   __macro(miopenConvolutionForwardGetWorkSpaceSize)                  \
279   __macro(miopenInitConvolutionDescriptor)                           \
280   __macro(miopenInitConvolutionNdDescriptor)                         \
281   __macro(miopenGetConvolutionDescriptor)                            \
282   __macro(miopenGetConvolutionNdDescriptor)                          \
283   __macro(miopenSetConvolutionGroupCount)                            \
284   __macro(miopenSet4dTensorDescriptor)                               \
285   __macro(miopenGetTensorDescriptor)                                 \
286   __macro(miopenSetTensorDescriptor)                                 \
287   __macro(miopenGetTensorDescriptorSize)                             \
288   __macro(miopenPoolingForward)                                      \
289   __macro(miopenPoolingGetWorkSpaceSizeV2)                           \
290   __macro(miopenPoolingBackward)                                     \
291   __macro(miopenLRNForward)                                          \
292   __macro(miopenLRNBackward)                                         \
293   __macro(miopenOpTensor)                                            \
294   __macro(miopenConvolutionBackwardData)                             \
295   __macro(miopenConvolutionBackwardWeights)                          \
296   __macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize)          \
297   __macro(miopenFindConvolutionBackwardDataAlgorithm)                \
298   __macro(miopenFindConvolutionBackwardWeightsAlgorithm)             \
299   __macro(miopenConvolutionBackwardDataGetWorkSpaceSize)             \
300   __macro(miopenCreateRNNDescriptor)                                 \
301   __macro(miopenSetRNNDescriptor)                                    \
302   __macro(miopenDestroyRNNDescriptor)                                \
303   __macro(miopenGetRNNParamsSize)                                    \
304   __macro(miopenGetRNNLayerParam)                                    \
305   __macro(miopenGetRNNLayerBias)                                     \
306   __macro(miopenGetRNNWorkspaceSize)                                 \
307   __macro(miopenGetRNNTrainingReserveSize)                           \
308   __macro(miopenRNNForwardInference)                                 \
309   __macro(miopenRNNForwardTraining)                                  \
310   __macro(miopenRNNBackwardData)                                     \
311   __macro(miopenRNNBackwardWeights)                                  \
312   __macro(miopenGetRNNLayerParamOffset)                              \
313   __macro(miopenGetRNNLayerParamSize)                                \
314   __macro(miopenGetRNNLayerBiasOffset)                               \
315   __macro(miopenGetRNNLayerBiasSize)                                 \
316   __macro(miopenGetRNNParamsDescriptor)                              \
317   __macro(miopenCreateActivationDescriptor)                          \
318   __macro(miopenSetActivationDescriptor)                             \
319   __macro(miopenGetActivationDescriptor)                             \
320   __macro(miopenDestroyActivationDescriptor)                         \
321   __macro(miopenCreateFusionPlan)                                    \
322   __macro(miopenCreateOpConvForward)                                 \
323   __macro(miopenCreateOpBiasForward)                                 \
324   __macro(miopenCreateOpActivationForward)                           \
325   __macro(miopenCreateOpActivationBackward)                          \
326   __macro(miopenCreateOpBatchNormInference)                          \
327   __macro(miopenCreateOpBatchNormForward)                            \
328   __macro(miopenCreateOpBatchNormBackward)                           \
329   __macro(miopenCompileFusionPlan)                                   \
330   __macro(miopenFusionPlanGetOp)                                     \
331   __macro(miopenCreateOperatorArgs)                                  \
332   __macro(miopenSetOpArgsConvForward)                                \
333   __macro(miopenSetOpArgsBiasForward)                                \
334   __macro(miopenSetOpArgsActivForward)                               \
335   __macro(miopenSetOpArgsActivBackward)                              \
336   __macro(miopenSetOpArgsBatchNormInference)                         \
337   __macro(miopenSetOpArgsBatchNormForward)                           \
338   __macro(miopenSetOpArgsBatchNormBackward)                          \
339   __macro(miopenExecuteFusionPlan)                                   \
340   __macro(miopenDestroyOperatorArgs)                                 \
341   __macro(miopenDestroyFusionPlan)                                   \
342   __macro(miopenConvolutionForwardGetSolutionCount)                  \
343   __macro(miopenConvolutionForwardGetSolution)                       \
344   __macro(miopenConvolutionForwardGetSolutionWorkspaceSize)          \
345   __macro(miopenConvolutionForwardCompileSolution)                   \
346   __macro(miopenConvolutionForwardImmediate)                         \
347   __macro(miopenConvolutionBackwardDataGetSolutionCount)             \
348   __macro(miopenConvolutionBackwardDataGetSolution)                  \
349   __macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize)     \
350   __macro(miopenConvolutionBackwardDataCompileSolution)              \
351   __macro(miopenConvolutionBackwardDataImmediate)                    \
352   __macro(miopenConvolutionBackwardWeightsGetSolutionCount)          \
353   __macro(miopenConvolutionBackwardWeightsGetSolution)               \
354   __macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize)  \
355   __macro(miopenConvolutionBackwardWeightsCompileSolution)           \
356   __macro(miopenConvolutionBackwardWeightsImmediate)                 \
357   __macro(miopenCreateCTCLossDescriptor)                             \
358   __macro(miopenSetCTCLossDescriptor)                                \
359   __macro(miopenGetCTCLossWorkspaceSize)                             \
360   __macro(miopenCTCLoss)                                             \
361   __macro(miopenDestroyCTCLossDescriptor)
362 // clang-format on
363 
364 MIOPEN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_MIOPEN_WRAP)
365 
366 #undef MIOPEN_DNN_ROUTINE_EACH
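// Note: the expansion above gives each listed MIOpen API a shim object in this
// namespace, so call sites can simply write, for example,
//   wrap::miopenCreateTensorDescriptor(&handle);
// In non-Google builds the symbol is resolved from the MIOpen DSO via
// GetSymbolFromLibrary on first use; in PLATFORM_GOOGLE builds the shim
// forwards directly to ::miopenCreateTensorDescriptor.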
367 
368 }  // namespace wrap
369 
370 namespace {
371 
372 // These routines should ideally be provided as an MIOpen API.
373 // They are called for *every* _ROCMmFusedOp*::Compute call, and they need to be
374 // efficient! Instead of calculating the hash value by querying the MIOpen Get*
375 // APIs for the descriptor components, it would be a lot more efficient if
376 // MIOpen calculated the hash value when creating the descriptor, stored it on
377 // the descriptor data structure, and provided an API routine to query it.
378 
379 const int kMaxMIOpenTensorSize = 5;
380 
381 uint64 GetHashValue(miopenTensorDescriptor_t tensor_desc) {
382   miopenDataType_t datatype = miopenFloat;
383   int dims[kMaxMIOpenTensorSize] = {0};
384   int strides[kMaxMIOpenTensorSize] = {0};
385   wrap::miopenGetTensorDescriptor(tensor_desc, &datatype, dims, strides);
386 
387   uint64 hash_value = tensorflow::hash<int>()(datatype);
388   for (int dim : dims)
389     hash_value =
390         tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(dim));
391   for (int stride : strides)
392     hash_value =
393         tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(stride));
394 
395   return hash_value;
396 }
397 
398 uint64 GetHashValue(miopenConvolutionDescriptor_t conv_desc) {
399   miopenConvolutionMode_t c_mode = miopenConvolution;
400   int nd = 0;
401   wrap::miopenGetConvolutionNdDescriptor(conv_desc, 0, &nd, nullptr, nullptr,
402                                          nullptr, &c_mode);
403 
404   std::vector<int> stride(nd);
405   std::vector<int> pad(nd);
406   std::vector<int> dilation(nd);
407 
408   wrap::miopenGetConvolutionNdDescriptor(
409       conv_desc, nd, &nd, pad.data(), stride.data(), dilation.data(), &c_mode);
410 
411   uint64 hash_value = tensorflow::hash<int>()(c_mode);
412   auto hash64Combine = [&hash_value](int element) {
413     tensorflow::Hash64Combine(hash_value, tensorflow::hash<int>()(element));
414   };
415   std::for_each(pad.begin(), pad.end(), hash64Combine);
416   std::for_each(stride.begin(), stride.end(), hash64Combine);
417   std::for_each(dilation.begin(), dilation.end(), hash64Combine);
418 
419   return hash_value;
420 }
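// Hypothetical example of combining the two helpers above into a single cache
// key (the exact combination used by the fused-op code later in this file may
// differ):
//   uint64 key = tensorflow::Hash64Combine(GetHashValue(input_descriptor),
//                                          GetHashValue(conv_descriptor));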
421 
422 // Class to implement a cache of compiled fusion plans
423 class CachedFusionPlans {
424  public:
425   // Check if we already have a fusion_plan corresponding to the given hash
426   // value.
427   // If we do, then
428   //   return true (+ the cached fusion plan via given pointer)
429   // Else
430   //   create a new fusion plan descriptor,
431   //   associate it with the given hash value in the cache
432   //   return false (+ newly created fusion plan via given pointer)
433   static bool FindOrCreate(uint64 hash,
434                            miopenFusionPlanDescriptor_t* fusion_plan,
435                            miopenFusionDirection_t fusion_direction,
436                            miopenTensorDescriptor_t input_descriptor) {
437     absl::MutexLock lock{&cached_plans_mutex};
438 
439     bool found_cached_plan = false;
440 
441     auto it = cached_plans.find(hash);
442     if (it != cached_plans.end()) {
443       *fusion_plan = it->second;
444       found_cached_plan = true;
445     } else {
446       auto status = wrap::miopenCreateFusionPlan(fusion_plan, fusion_direction,
447                                                  input_descriptor);
448       if (status != miopenStatusSuccess) {
449         LOG(FATAL) << "call to miopenCreateFusionPlan failed: "
450                    << ToString(status);
451       } else {
452         cached_plans[hash] = *fusion_plan;
453       }
454     }
455 
456     return found_cached_plan;
457   }
458 
459   // Need to figure out the right place to call this routine
460   static void Clear() {
461     absl::MutexLock lock{&cached_plans_mutex};
462 
463     for (auto it : cached_plans) {
464       auto status = wrap::miopenDestroyFusionPlan(it.second);
465       if (status != miopenStatusSuccess) {
466         LOG(FATAL) << "call to miopenDestroyFusionPlan failed: "
467                    << ToString(status);
468       }
469     }
470 
471     cached_plans.clear();
472 
473     unsupported_plans.clear();
474   }
475 
476   // Returns whether the fusion plan corresponding to this hash is unsupported.
477   static bool IsUnsupportedFusionPlan(uint64 hash) {
478     absl::MutexLock lock{&cached_plans_mutex};
479     return unsupported_plans.count(hash) > 0;
480   }
481 
482   // Mark the given hash value as corresponding to an unsupported fusion plan
483   static void MarkFusionPlanUnsupported(uint64 hash) {
484     absl::MutexLock lock{&cached_plans_mutex};
485     unsupported_plans.insert(hash);
486   }
487 
488  private:
489   // Mutex to guard access to all data within this class
490   static absl::Mutex cached_plans_mutex;
491 
492   // Map of hash-value to MIOpen fusion plan descriptors.
493   // Needs to be shareable across more than one stream, hence it is static.
494   static std::map<uint64, miopenFusionPlanDescriptor_t> cached_plans;
495 
496   // Set of hash-values that correspond to MIOpen Fusion plans that will fail
497   // compile and hence are not supported.
498   static std::set<uint64> unsupported_plans;
499 };
500 
501 absl::Mutex CachedFusionPlans::cached_plans_mutex;
502 std::map<uint64, miopenFusionPlanDescriptor_t> CachedFusionPlans::cached_plans;
503 std::set<uint64> CachedFusionPlans::unsupported_plans;
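// Usage sketch for the cache above (names are illustrative):
//   miopenFusionPlanDescriptor_t fusion_plan;
//   bool found = CachedFusionPlans::FindOrCreate(hash, &fusion_plan,
//                                                miopenVerticalFusion, input_nd);
//   if (!found) { /* populate the ops and compile the plan exactly once */ }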
504 
505 dnn::ProfileResult GetProfileResultFromConvSolution(
506     miopenConvSolution_t solution) {
507   dnn::ProfileResult profile_result;
508   profile_result.set_algorithm(
509       {static_cast<AlgorithmDesc::Index>(solution.solution_id), false});
510   profile_result.set_elapsed_time_in_ms(solution.time);
511   profile_result.set_scratch_size(solution.workspace_size);
512   return profile_result;
513 }
514 
515 dnn::ProfileResult GetProfileResultFromConvAlgoPerf(
516     dnn::ConvolutionKind kind, miopenConvAlgoPerf_t algorithm) {
517   dnn::ProfileResult profile_result;
518   switch (kind) {
519     case dnn::ConvolutionKind::FORWARD:
520       profile_result.set_algorithm(
521           {static_cast<AlgorithmDesc::Index>(algorithm.fwd_algo), false});
522       break;
523     case dnn::ConvolutionKind::BACKWARD_DATA:
524       profile_result.set_algorithm(
525           {static_cast<AlgorithmDesc::Index>(algorithm.bwd_data_algo), false});
526       break;
527     case dnn::ConvolutionKind::BACKWARD_FILTER:
528       profile_result.set_algorithm(
529           {static_cast<AlgorithmDesc::Index>(algorithm.bwd_weights_algo),
530            false});
531       break;
532     default:
533       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
534       break;
535   }
536   profile_result.set_elapsed_time_in_ms(algorithm.time);
537   profile_result.set_scratch_size(algorithm.memory);
538   return profile_result;
539 }
540 }  // namespace
541 
542 // Wraps a MIOpen handle and provides access to it through MIOpenHandle
543 // instances, each of which also locks a mutex, acquires the ROCm context, and sets
544 // the stream that MIOpen should use to enqueue any work.
545 //
546 // Note: MIOpenSupport::miopen_ should be the only instantiation of this class.
547 class MIOpenAccess {
548  public:
549   // Takes ownership of the handle.
550   explicit MIOpenAccess(miopenHandle_t handle) : handle_(handle) {}
551 
552   ~MIOpenAccess() {
553     absl::MutexLock lock(&mutex_);
554     wrap::miopenDestroy(handle_);
555   }
556 
557   // Creates a MIOpenHandle instance for stream.
558   //
559   // MIOpen API calls using the same handle instance need to be serialized
560   // across threads. This is guaranteed by MIOpenHandle instances locking the
561   // mutex owned by this class.
562   //
563   // Most MIOpen APIs taking a handle perform work on a HIP stream. The
564   // MIOpenHandle instance acquires the executor's ROCm context and sets MIOpen
565   // to use the provided stream.
566   //
567   // The stream argument may be null, which translates to the null stream.
568   // The null stream synchronizes with all other streams and it is
569   // therefore a bad idea (performance wise) to call any MIOpen APIs that
570   // enqueue work in the stream.
571   MIOpenHandle GetHandle(GpuExecutor* executor, Stream* stream) {
572     auto lock = absl::make_unique<absl::MutexLock>(&mutex_);
573     mutex_.AssertHeld();
574     gpu::ScopedActivateExecutorContext context(executor);
575     hipStream_t hip_stream = stream ? AsGpuStreamValue(stream) : nullptr;
576     auto status = wrap::miopenSetStream(handle_, hip_stream);
577     CHECK_EQ(status, miopenStatusSuccess) << "Failed to set MIOpen stream.";
578     return MIOpenHandle(std::move(context), std::move(lock), handle_);
579   }
580 
581  private:
582   // Guards the enqueueing of MIOpen operations via the handle_ below.
583   absl::Mutex mutex_;
584 
585   // MIOpen library handle.
586   miopenHandle_t handle_ TF_GUARDED_BY(mutex_);  // Owned.
587 };
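// Typical access pattern for MIOpenAccess (sketch): the mutex and the ROCm
// context are held only for the lifetime of the returned MIOpenHandle, e.g.
//   auto miopen = miopen_->GetHandle(parent_, stream);
//   wrap::miopenSomeApiCall(miopen.handle(), ...);  // placeholder API name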
588 
589 MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) {
590   // by default, the Get*Algorithm API will return the list of all applicable
591   // algorithms
592   return_best_algo_only_ = false;
593   // but if the env var TF_ROCM_RETURN_BEST_ALGO_ONLY is set, only the best
594   // (i.e. most efficient) algorithm will be returned
595   tensorflow::ReadBoolFromEnvVar("TF_ROCM_RETURN_BEST_ALGO_ONLY", false,
596                                  &return_best_algo_only_);
597 
598   // by default, use Find Mode APIs for convolution
599   use_immediate_mode_ = false;
600   // switch to Immediate Mode if the env var TF_ROCM_USE_IMMEDIATE_MODE is set
601   tensorflow::ReadBoolFromEnvVar("TF_ROCM_USE_IMMEDIATE_MODE", false,
602                                  &use_immediate_mode_);
603 
604   bool enable_pooling_cache = false;
605   tensorflow::ReadBoolFromEnvVar("TF_ROCM_BW_POOL_CACHE", false,
606                                  &enable_pooling_cache);
607   if (enable_pooling_cache) m_pooling_cache_allowed = true;
608 }
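// Environment variables consulted by the constructor above (all default to
// unset/false):
//   TF_ROCM_RETURN_BEST_ALGO_ONLY=1  -> return only the most efficient algorithm
//   TF_ROCM_USE_IMMEDIATE_MODE=1     -> use Immediate Mode instead of Find Mode
//   TF_ROCM_BW_POOL_CACHE=1          -> allow the backward-pooling cache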
609 
610 port::Status MIOpenSupport::Init() {
611   ScopedActivateExecutorContext context(parent_);
612   miopenHandle_t miopen_handle = nullptr;
613   auto status = wrap::miopenCreateWithStream(
614       reinterpret_cast<miopenHandle_t*>(&miopen_handle), (hipStream_t)(0));
615   if (status == miopenStatusSuccess) {
616     miopen_.reset(new MIOpenAccess(miopen_handle));
617     return port::Status::OK();
618   }
619 
620   CHECK_EQ(miopen_handle, nullptr);
621   LOG(ERROR) << "could not create miopen handle: " << ToString(status);
622   if (status == miopenStatusNotInitialized) {
623     auto result = rocm::Diagnostician::FindKernelDriverVersion();
624     if (!result.ok()) {
625       LOG(ERROR) << "error retrieving driver version: "
626                  << rocm::DriverVersionStatusToString(result);
627     } else {
628       const auto& version = result.ValueOrDie();
629       LOG(INFO) << "possibly insufficient driver version: "
630                 << rocm::DriverVersionToString(version);
631     }
632   }
633 
634   return port::Status{port::error::INTERNAL,
635                       absl::StrCat("miopen library could not create a handle: ",
636                                    ToString(status))};
637 }
638 
639 port::StatusOr<perftools::gputools::dnn::VersionInfo>
640 MIOpenSupport::GetVersion() {
641   // ROCM TODO: retrieve MIOpen version with its API
642   return perftools::gputools::dnn::VersionInfo(1, 3, 0);
643 }
644 
645 // Turns a BatchDescriptor structure into a miopen tensor handle within a scope.
646 class ScopedTensorDescriptor {
647  public:
648   ScopedTensorDescriptor(const BatchDescriptor& batch_descriptor,
649                          miopenDataType_t elem_type)
650       : handle_(nullptr) {
651     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
652     if (status != miopenStatusSuccess) {
653       LOG(FATAL) << "could not create miopen tensor descriptor: "
654                  << ToString(status);
655     }
656 
657     switch (batch_descriptor.layout()) {
658       case dnn::DataLayout::kBatchYXDepth:
659       case dnn::DataLayout::kBatchDepthYX: {
660         const int nd = batch_descriptor.ndims() + 2;
661 
662         // MIOpen requires the strides and dims to be ordered as BDYX.
663         std::vector<int64> strides64 =
664             batch_descriptor.full_strides(dnn::DataLayout::kBatchDepthYX);
665         std::vector<int64> dims64 =
666             batch_descriptor.full_dims(dnn::DataLayout::kBatchDepthYX);
667 
668         // MIOpen requires arrays of ints.
669         std::vector<int> strides(nd);
670         std::vector<int> dims(nd);
671         std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
672                        &CheckedNarrowing<int64, int>);
673         std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
674                        &CheckedNarrowing<int64, int>);
675         status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd,
676                                                  dims.data(), strides.data());
677 
678         if (status != miopenStatusSuccess) {
679           LOG(FATAL) << "could not convert BatchDescriptor "
680                      << batch_descriptor.ToString()
681                      << " to miopen tensor descriptor: " << ToString(status);
682         }
683       } break;
684       default:
685         LOG(FATAL) << "Unsupported tensor format "
686                    << DataLayoutString(batch_descriptor.layout());
687         break;
688     }
689   }
690 
691   ~ScopedTensorDescriptor() {
692     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
693     if (status != miopenStatusSuccess) {
694       LOG(ERROR) << "could not destroy miopen tensor descriptor: "
695                  << ToString(status);
696     }
697   }
698 
699   miopenTensorDescriptor_t handle() const { return handle_; }
700 
701  private:
702   miopenTensorDescriptor_t handle_;  // Owned.
703 
704   SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
705 };
706 
707 // Turns a FilterDescriptor structure into a miopen filter handle within a
708 // scope.
709 class ScopedFilterDescriptor {
710  public:
711   ScopedFilterDescriptor(const FilterDescriptor& filter_descriptor,
712                          miopenDataType_t elem_type)
713       : handle_(nullptr) {
714     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
715     if (status != miopenStatusSuccess) {
716       LOG(FATAL) << "could not create miopen filter descriptor: "
717                  << ToString(status);
718     }
719 
720     // We need to pass two vectors to the miopenSetTensorDescriptor routine
721     // "dims" (length == number of dims, elem value == dimension size)
722     // "strides" (length == number of dims, elem value == stride size)
723     //
724     // Irrespective of the actual filter layout, the indexing of both those
725     // vectors must be the following (because that is what MIOpen expects):
726     // dims[0] = strides[0] = N or output
727     // dims[1] = strides[1] = C or input
728     // dims[2] = strides[2] = H or spatial dim 0
729     // dims[3] = strides[3] = W or spatial dim 1
730     //
731     // assume you have a tensor with dimensions
732     // batch descriptor name    filter descriptor name    value
733     //   N (batch size)            O (output features)    256
734     //   C (channels)              I (input features)       3
735     //   H (height)                H (height)               7
736     //   W (width)                 W (width)                5
737     //
738     // The content of "dims" will be the same irrespective of the layout
739     // (NCHW or NHWC), and MIOpen expects it to be
740     //                           NCHW layout   NHWC layout
741     // dims[0] = size of N dim =    256           256
742     // dims[1] = size of C dim =      3             3
743     // dims[2] = size of H dim =      7             7
744     // dims[3] = size of W dim =      5             5
745     //
746     // The content of "strides" will be different based on layout
747     //                                  NCHW layout   NHWC layout
748     //  strides[0] = stride of N dim =     7x5x3       7x5x3
749     //  strides[1] = stride of C dim =     7x5         1
750     //  strides[2] = stride of H dim =     5           5x3
751     //  strides[3] = stride of W dim =     1           3
752 
753     switch (filter_descriptor.layout()) {
754       case dnn::FilterLayout::kOutputYXInput:
755       case dnn::FilterLayout::kOutputInputYX: {
756         const int nd = filter_descriptor.ndims() + 2;
757 
758         // MIOpen requires the strides and dims to be ordered as BDYX.
759         std::vector<int64> strides64 =
760             filter_descriptor.full_strides(dnn::FilterLayout::kOutputInputYX);
761         std::vector<int64> dims64 =
762             filter_descriptor.full_dims(dnn::FilterLayout::kOutputInputYX);
763 
764         // MIOpen requires arrays of ints.
765         std::vector<int> strides;
766         std::vector<int> dims;
767         absl::c_transform(strides64, std::back_inserter(strides),
768                           &CheckedNarrowing<int64, int>);
769         absl::c_transform(dims64, std::back_inserter(dims),
770                           &CheckedNarrowing<int64, int>);
771         status = wrap::miopenSetTensorDescriptor(handle_, elem_type, nd,
772                                                  dims.data(), strides.data());
773 
774         if (status != miopenStatusSuccess) {
775           LOG(FATAL) << "could not convert FilterDescriptor "
776                      << filter_descriptor.ToString()
777                      << " to miopen tensor descriptor: " << ToString(status);
778         }
779       } break;
780       default:
781         LOG(FATAL) << "Unsupported tensor format "
782                    << FilterLayoutString(filter_descriptor.layout());
783         break;
784     }
785   }
786 
787   ~ScopedFilterDescriptor() {
788     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
789     if (status != miopenStatusSuccess) {
790       LOG(ERROR) << "could not destroy miopen filter descriptor: "
791                  << ToString(status);
792     }
793   }
794 
795   miopenTensorDescriptor_t handle() const { return handle_; }
796 
797  private:
798   // miopen filter descriptor this object creates. Owned.
799   miopenTensorDescriptor_t handle_;
800 
801   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
802 };
803 
804 // Turns a ConvolutionDescriptor structure into a miopen convolution handle
805 // within a scope.
806 class ScopedConvolutionDescriptor {
807  public:
808   ScopedConvolutionDescriptor(
809       const ConvolutionDescriptor& convolution_descriptor,
810       miopenDataType_t data_type)
811       : handle_(nullptr) {
812     auto status = wrap::miopenCreateConvolutionDescriptor(&handle_);
813     if (status != miopenStatusSuccess) {
814       LOG(FATAL) << "could not create miopen convolution descriptor: "
815                  << ToString(status);
816     }
817     const auto& strides64 = convolution_descriptor.strides();
818     const auto& padding64 = convolution_descriptor.padding();
819     if (convolution_descriptor.pad_alignment() ==
820         dnn::PadAlignment::kTensorFlowPadding) {
821       LOG(ERROR) << "TensorFlow padding alignment is not supported.";
822     }
823 
824     // MIOpen requires arrays of ints.
825     std::vector<int> strides(convolution_descriptor.ndims());
826     std::vector<int> padding(convolution_descriptor.ndims());
827     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
828                    &CheckedNarrowing<int64, int>);
829     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
830                    &CheckedNarrowing<int64, int>);
831 
832     std::vector<int> upscale(convolution_descriptor.ndims());
833     const auto& dilations64 = convolution_descriptor.dilations();
834     std::transform(dilations64.cbegin(), dilations64.cend(), upscale.begin(),
835                    &CheckedNarrowing<int64, int>);
836 
837     status = wrap::miopenInitConvolutionNdDescriptor(
838         handle_, convolution_descriptor.ndims(), padding.data(), strides.data(),
839         upscale.data(), miopenConvolution);
840     if (status != miopenStatusSuccess) {
841       LOG(FATAL) << "could not set miopen convolution descriptor: "
842                  << ToString(status);
843     }
844 
845     VLOG(2) << "Requesting grouped convolution: "
846             << convolution_descriptor.group_count();
847     status = wrap::miopenSetConvolutionGroupCount(
848         handle_, convolution_descriptor.group_count());
849     if (status != miopenStatusSuccess) {
850       LOG(FATAL) << "could not set miopen convolution group count: "
851                  << ToString(status);
852     }
853   }
854   ~ScopedConvolutionDescriptor() {
855     auto status = wrap::miopenDestroyConvolutionDescriptor(handle_);
856     if (status != miopenStatusSuccess) {
857       LOG(ERROR) << "could not destroy miopen convolution descriptor: "
858                  << ToString(status);
859     }
860   }
861 
862   miopenConvolutionDescriptor_t handle() const { return handle_; }
863 
864  private:
865   miopenConvolutionDescriptor_t handle_;  // Owned.
866 
867   SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
868 };
869 
870 // Turns a PoolingDescriptor structure into a miopen pooling descriptor handle
871 // within a scope.
872 class ScopedPoolingDescriptor {
873  public:
874   ScopedPoolingDescriptor(const PoolingDescriptor& pooling_descriptor)
875       : handle_(nullptr) {
876     auto status = wrap::miopenCreatePoolingDescriptor(&handle_);
877     if (status != miopenStatusSuccess) {
878       LOG(FATAL) << "could not create miopen pooling descriptor: "
879                  << ToString(status);
880     }
881 
882     absl::Span<const int64> strides64 = pooling_descriptor.strides();
883     absl::Span<const int64> padding64 = pooling_descriptor.padding();
884     absl::Span<const int64> shape64 = pooling_descriptor.window();
885 
886     const int nd = pooling_descriptor.ndims();
887     std::vector<int> shape(nd);
888     std::vector<int> padding(nd);
889     std::vector<int> strides(nd);
890     std::transform(strides64.cbegin(), strides64.cend(), strides.begin(),
891                    &CheckedNarrowing<int64, int>);
892     std::transform(padding64.cbegin(), padding64.cend(), padding.begin(),
893                    &CheckedNarrowing<int64, int>);
894     std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
895                    &CheckedNarrowing<int64, int>);
896 
897     status = wrap::miopenSetNdPoolingDescriptor(
898         handle_,
899         (pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
900              ? miopenPoolingMax
901              : miopenPoolingAverage),
902         nd, shape.data(), padding.data(), strides.data());
903 
904     // Note: The index type has to be uint32 for now because the MIOpen
905     // API assumes all input indices are of the same type. Since a tensor
906     // descriptor can only use the int32 type, the index type here needs to
907     // be aligned with the index type of the (input) tensor descriptor.
908     status = wrap::miopenSetPoolingIndexType(handle_, miopenIndexUint32);
909 
910     if (status != miopenStatusSuccess) {
911       LOG(FATAL) << "could not set miopen pooling descriptor: "
912                  << ToString(status);
913     }
914   }
915   ~ScopedPoolingDescriptor() {
916     auto status = wrap::miopenDestroyPoolingDescriptor(handle_);
917     if (status != miopenStatusSuccess) {
918       LOG(ERROR) << "could not destroy miopen pooling descriptor: "
919                  << ToString(status);
920     }
921   }
922 
923   miopenPoolingDescriptor_t handle() const { return handle_; }
924 
925  private:
926   miopenPoolingDescriptor_t handle_;  // Owned.
927 
928   SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
929 };
930 
931 // Turns a NormalizeDescriptor structure into a miopen LRN descriptor handle.
932 class ScopedNormalizeDescriptor {
933  public:
934   ScopedNormalizeDescriptor(const NormalizeDescriptor& normalize_descriptor)
935       : handle_(nullptr) {
936     auto status = wrap::miopenCreateLRNDescriptor(&handle_);
937     if (status != miopenStatusSuccess) {
938       LOG(FATAL) << "could not create miopen LRN descriptor: "
939                  << ToString(status);
940     }
941 
942     // The range specifies that the indices in the closed range
943     // [i - range, i + range] should be included in the normalization for index
944     // i. The lrnN value is the total number of elements in the range, so
945     // lrnN = 2*range + 1.
946     unsigned lrn_N = 2 * normalize_descriptor.range() + 1;
947 
948     // Note that SE defines the normalization operation as
949     //
950     //  U_i = V_i / ((bias +  alpha      * (sum_j V_j^2)) ^ beta)
951     //
952     // but MIOpen defines it as
953     //
954     //  U_i = V_i / ((bias + (alpha / n) * (sum_j V_j^2)) ^ beta)
955     //
956     // i.e. there is a factor of n difference between the meaning of the alphas
957     // in the two contexts. The MIOpen alpha is n times the SE alpha.
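    // For example, with range = 2: lrn_N = 2 * 2 + 1 = 5, and the alpha passed
    // to MIOpen below is 5 times the SE alpha.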
958     double lrn_alpha = lrn_N * normalize_descriptor.alpha();
959 
960     double lrn_beta = normalize_descriptor.beta();
961     double lrn_k = normalize_descriptor.bias();
962     status = wrap::miopenSetLRNDescriptor(handle_, miopenLRNCrossChannel, lrn_N,
963                                           lrn_alpha, lrn_beta, lrn_k);
964     if (status != miopenStatusSuccess) {
965       LOG(FATAL) << "could not set miopen LRN descriptor: " << ToString(status);
966     }
967   }
968 
969   ~ScopedNormalizeDescriptor() {
970     auto status = wrap::miopenDestroyLRNDescriptor(handle_);
971     if (status != miopenStatusSuccess) {
972       LOG(ERROR) << "could not destroy miopen LRN descriptor: "
973                  << ToString(status);
974     }
975   }
976 
977   miopenLRNDescriptor_t handle() const { return handle_; }
978 
979  private:
980   miopenLRNDescriptor_t handle_;  // Owned.
981 
982   SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
983 };
984 
985 // Turns an activation mode into a miopen activation descriptor within a
986 // scope.
987 class ScopedActivationDescriptor {
988  public:
989   ScopedActivationDescriptor(dnn::ActivationMode activation_mode)
990       : handle_(nullptr),
991         miopen_activation_mode_(miopenActivationPASTHRU),
992         alpha_(0.0),
993         beta_(0.0),
994         gamma_(0.0) {
995     auto status = wrap::miopenCreateActivationDescriptor(&handle_);
996     if (status != miopenStatusSuccess) {
997       LOG(FATAL) << "call to miopenCreateActivationDescriptor failed: "
998                  << ToString(status);
999     } else {
1000       switch (activation_mode) {
1001         case dnn::ActivationMode::kNone:
1002           miopen_activation_mode_ = miopenActivationPASTHRU;
1003           break;
1004 
1005         case dnn::ActivationMode::kSigmoid:
1006           miopen_activation_mode_ = miopenActivationLOGISTIC;
1007           break;
1008 
1009         case dnn::ActivationMode::kRelu:
1010           miopen_activation_mode_ = miopenActivationRELU;
1011           break;
1012 
1013         case dnn::ActivationMode::kRelu6:
1014           miopen_activation_mode_ = miopenActivationRELU;
1015           alpha_ = 6.0;
1016           break;
1017 
1018         case dnn::ActivationMode::kTanh:
1019           miopen_activation_mode_ = miopenActivationTANH;
1020           break;
1021 
1022         default:
1023           LOG(FATAL) << "Activation mode ("
1024                      << dnn::ActivationModeString(activation_mode)
1025                      << ") not yet implemented";
1026           break;
1027       }
1028 
1029       status = wrap::miopenSetActivationDescriptor(
1030           handle_, miopen_activation_mode_, alpha_, beta_, gamma_);
1031       if (status != miopenStatusSuccess) {
1032         LOG(FATAL) << "call to miopenSetActivationDescriptor failed: "
1033                    << ToString(status);
1034       }
1035     }
1036   }
1037 
1038   ~ScopedActivationDescriptor() {
1039     auto status = wrap::miopenDestroyActivationDescriptor(handle_);
1040     if (status != miopenStatusSuccess) {
1041       LOG(FATAL) << "call to miopenDestroyActivationDescriptor failed: "
1042                  << ToString(status);
1043     }
1044   }
1045 
1046   miopenActivationDescriptor_t handle() const { return handle_; }
1047 
1048   uint64 GetHashValue() {
1049     uint64 hash_value = tensorflow::hash<int>()(miopen_activation_mode_);
1050     hash_value = tensorflow::Hash64Combine(hash_value,
1051                                            tensorflow::hash<double>()(alpha_));
1052     hash_value = tensorflow::Hash64Combine(hash_value,
1053                                            tensorflow::hash<double>()(beta_));
1054     hash_value = tensorflow::Hash64Combine(hash_value,
1055                                            tensorflow::hash<double>()(gamma_));
1056 
1057     return hash_value;
1058   }
1059 
1060  private:
1061   miopenActivationDescriptor_t handle_;  // Owned.
1062 
1063   SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
1064 
1065  public:
1066   // Caching these values here avoids having to call
1067   // miopenGetActivationDescriptor to retrieve them. That API would otherwise get
1068   // called twice during each execution of a fusion plan (that involves the
1069   // activation op): once while calculating the hash value for the fusion op, and
1070   // again before calling SetOpArgs for the activation op.
1071   miopenActivationMode_t miopen_activation_mode_;
1072   double alpha_;
1073   double beta_;
1074   double gamma_;
1075 };
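// Example of the mapping performed by ScopedActivationDescriptor (from the
// switch in its constructor): dnn::ActivationMode::kRelu6 maps to
// miopenActivationRELU with alpha_ == 6.0, and dnn::ActivationMode::kNone maps
// to miopenActivationPASTHRU.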
1076 
1077 // Base class for all fusion plan implementations to derive from.
1078 class ScopedFusionPlanBase {
1079  public:
1080   ScopedFusionPlanBase(miopenHandle_t miopen_handle,
1081                        const miopenFusionDirection_t fuse_direction,
1082                        const miopenTensorDescriptor_t input_descriptor)
1083       : miopen_handle_(miopen_handle),
1084         fusion_plan_(nullptr),
1085         fusion_args_(nullptr),
1086         fusion_plan_compiled_(false) {
1087     auto status = wrap::miopenCreateOperatorArgs(&fusion_args_);
1088     if (status != miopenStatusSuccess) {
1089       LOG(FATAL) << "call to miopenCreateOperatorArgs failed: "
1090                  << ToString(status);
1091     }
1092   }
1093 
1094   virtual ~ScopedFusionPlanBase() {
1095     auto status = wrap::miopenDestroyOperatorArgs(fusion_args_);
1096     if (status != miopenStatusSuccess) {
1097       LOG(FATAL) << "call to miopenDestroyOperatorArgs failed: "
1098                  << ToString(status);
1099     }
1100   }
1101 
1102   miopenStatus_t Execute(miopenTensorDescriptor_t input_descriptor,
1103                          const void* input_data,
1104                          miopenTensorDescriptor_t output_descriptor,
1105                          void* output_data) {
1106     auto status = wrap::miopenExecuteFusionPlan(
1107         miopen_handle_, fusion_plan_, input_descriptor, input_data,
1108         output_descriptor, output_data, fusion_args_);
1109     if (status != miopenStatusSuccess) {
1110       LOG(FATAL) << "call to miopenExecuteFusionPlan failed: "
1111                  << ToString(status);
1112     }
1113 
1114     return status;
1115   }
1116 
1117   bool CompilationSucceeded() { return fusion_plan_compiled_; }
1118 
1119  protected:
1120   miopenStatus_t SetConvolutionArgs(const int op_idx, const float* alpha,
1121                                     const float* beta, const void* data) {
1122     miopenFusionOpDescriptor_t conv_op;
1123     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &conv_op);
1124     if (status != miopenStatusSuccess) {
1125       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1126                  << ToString(status);
1127     }
1128 
1129     status = wrap::miopenSetOpArgsConvForward(fusion_args_, conv_op, alpha,
1130                                               beta, data);
1131     if (status != miopenStatusSuccess) {
1132       LOG(FATAL) << "call to miopenSetOpArgsConvForward failed: "
1133                  << ToString(status);
1134     }
1135     return status;
1136   }
1137 
1138   miopenStatus_t SetBiasArgs(const int op_idx, const float* alpha,
1139                              const float* beta, const void* data) {
1140     miopenFusionOpDescriptor_t bias_op;
1141     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &bias_op);
1142     if (status != miopenStatusSuccess) {
1143       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1144                  << ToString(status);
1145     }
1146 
1147     status = wrap::miopenSetOpArgsBiasForward(fusion_args_, bias_op, alpha,
1148                                               beta, data);
1149     if (status != miopenStatusSuccess) {
1150       LOG(FATAL) << "call to miopenSetOpArgsBiasForward failed: "
1151                  << ToString(status);
1152     }
1153     return status;
1154   }
1155 
1156   miopenStatus_t SetBatchNormInferenceArgs(const int op_idx, const float* alpha,
1157                                            const float* beta, const void* scale,
1158                                            const void* offset, const void* mean,
1159                                            const void* variance,
1160                                            double epsilon) {
1161     miopenFusionOpDescriptor_t batchnorm_op;
1162     auto status =
1163         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1164     if (status != miopenStatusSuccess) {
1165       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1166                  << ToString(status);
1167     }
1168 
1169     status = wrap::miopenSetOpArgsBatchNormInference(fusion_args_, batchnorm_op,
1170                                                      alpha, beta, scale, offset,
1171                                                      mean, variance, epsilon);
1172     if (status != miopenStatusSuccess) {
1173       LOG(FATAL) << "call to miopenSetOpArgsBatchNormInference failed: "
1174                  << ToString(status);
1175     }
1176     return status;
1177   }
1178 
1179   miopenStatus_t SetBatchNormForwardArgs(
1180       const int op_idx, const float* alpha, const float* beta,
1181       const void* scale, const void* offset, void* running_mean,
1182       void* running_variance, void* saved_mean, void* saved_inv_variance,
1183       double epsilon, double exponential_average_factor) {
1184     miopenFusionOpDescriptor_t batchnorm_op;
1185     auto status =
1186         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1187     if (status != miopenStatusSuccess) {
1188       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1189                  << ToString(status);
1190     }
1191 
1192     status = wrap::miopenSetOpArgsBatchNormForward(
1193         fusion_args_, batchnorm_op, alpha, beta, scale, offset, saved_mean,
1194         saved_inv_variance, running_mean, running_variance, epsilon,
1195         exponential_average_factor);
1196     if (status != miopenStatusSuccess) {
1197       LOG(FATAL) << "call to miopenSetOpArgsBatchNormForward failed: "
1198                  << ToString(status);
1199     }
1200     return status;
1201   }
1202 
1203   miopenStatus_t SetBatchNormBackwardArgs(const int op_idx, const float* alpha,
1204                                           const float* beta, const void* x,
1205                                           const void* scale, const void* offset,
1206                                           void* scale_grad, void* offset_grad,
1207                                           const void* saved_mean,
1208                                           const void* saved_inv_variance) {
1209     miopenFusionOpDescriptor_t batchnorm_op;
1210     auto status =
1211         wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &batchnorm_op);
1212     if (status != miopenStatusSuccess) {
1213       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1214                  << ToString(status);
1215     }
1216 
1217     status = wrap::miopenSetOpArgsBatchNormBackward(
1218         fusion_args_, batchnorm_op, alpha, beta, x, scale, offset, scale_grad,
1219         offset_grad, saved_mean, saved_inv_variance);
1220     if (status != miopenStatusSuccess) {
1221       LOG(FATAL) << "call to miopenSetOpArgsBatchNormBackward failed: "
1222                  << ToString(status);
1223     }
1224     return status;
1225   }
1226 
1227   miopenStatus_t SetActivationForwardArgs(const int op_idx, const float* alpha,
1228                                           const float* beta, double activ_alpha,
1229                                           double activ_beta,
1230                                           double activ_gamma) {
1231     miopenFusionOpDescriptor_t actv_op;
1232     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op);
1233     if (status != miopenStatusSuccess) {
1234       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1235                  << ToString(status);
1236     }
1237 
1238     status =
1239         wrap::miopenSetOpArgsActivForward(fusion_args_, actv_op, alpha, beta,
1240                                           activ_alpha, activ_beta, activ_gamma);
1241     if (status != miopenStatusSuccess) {
1242       LOG(FATAL) << "call to miopenSetOpArgsActivForward failed: "
1243                  << ToString(status);
1244     }
1245     return status;
1246   }
1247 
1248   miopenStatus_t SetActivationBackwardArgs(const int op_idx, const float* alpha,
1249                                            const float* beta, const void* y,
1250                                            double activ_alpha,
1251                                            double activ_beta,
1252                                            double activ_gamma) {
1253     miopenFusionOpDescriptor_t actv_op;
1254     auto status = wrap::miopenFusionPlanGetOp(fusion_plan_, op_idx, &actv_op);
1255     if (status != miopenStatusSuccess) {
1256       LOG(FATAL) << "call to miopenFusionPlanGetOp failed: "
1257                  << ToString(status);
1258     }
1259 
1260     status = wrap::miopenSetOpArgsActivBackward(fusion_args_, actv_op, alpha,
1261                                                 beta, y, nullptr, activ_alpha,
1262                                                 activ_beta, activ_gamma);
1263     if (status != miopenStatusSuccess) {
1264       LOG(FATAL) << "call to miopenSetOpArgsActivBackward failed: "
1265                  << ToString(status);
1266     }
1267     return status;
1268   }
1269 
1270   miopenHandle_t miopen_handle_;
1271   miopenFusionPlanDescriptor_t fusion_plan_;
1272   miopenOperatorArgs_t fusion_args_;  // Owned.
1273   bool fusion_plan_compiled_;
1274 
1275   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBase);
1276 };
1277 
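// The ScopedFusionPlan* classes below each assemble a fixed sequence of fused
// MIOpen operators, compile the plan once, and cache the compiled plan in
// CachedFusionPlans keyed by a hash of the participating descriptors, so that
// later requests with identical descriptors reuse it. A rough usage sketch
// (illustrative only; the exact call sites live elsewhere in this file):
//
//   ScopedFusionPlanConvolutionBiasActivation plan(
//       handle, input_desc, filter_desc, conv_desc, bias_desc, activation);
//   plan.SetConvolutionArgs(filter_data);
//   plan.SetBiasArgs(bias_data);
//   plan.SetActivationForwardArgs(activation);
//   // ...the compiled plan is then executed with the fusion args set above.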
1278 // class to represent the Convolution+Bias+Activation fusion plan
1279 class ScopedFusionPlanConvolutionBiasActivation : public ScopedFusionPlanBase {
1280  public:
1281   ScopedFusionPlanConvolutionBiasActivation(
1282       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1283       miopenTensorDescriptor_t filter_descriptor,
1284       miopenConvolutionDescriptor_t conv_descriptor,
1285       miopenTensorDescriptor_t bias_descriptor,
1286       ScopedActivationDescriptor& activation_descriptor)
1287       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1288                              input_descriptor) {
1289     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1290                                        filter_descriptor, conv_descriptor,
1291                                        bias_descriptor, activation_descriptor);
1292 
1293     bool is_compiled = CachedFusionPlans::FindOrCreate(
1294         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1295     if (!is_compiled) {
1296       miopenFusionOpDescriptor_t conv_op;
1297       auto status = wrap::miopenCreateOpConvForward(
1298           fusion_plan_, &conv_op, conv_descriptor, filter_descriptor);
1299       if (status != miopenStatusSuccess) {
1300         LOG(FATAL) << "call to miopenCreateOpConvForward failed: "
1301                    << ToString(status);
1302       }
1303 
1304       miopenFusionOpDescriptor_t bias_op;
1305       status = wrap::miopenCreateOpBiasForward(fusion_plan_, &bias_op,
1306                                                bias_descriptor);
1307       if (status != miopenStatusSuccess) {
1308         LOG(FATAL) << "call to miopenCreateOpBiasForward failed: "
1309                    << ToString(status);
1310       }
1311 
1312       miopenFusionOpDescriptor_t actv_op;
1313       status = wrap::miopenCreateOpActivationForward(
1314           fusion_plan_, &actv_op,
1315           activation_descriptor.miopen_activation_mode_);
1316       if (status != miopenStatusSuccess) {
1317         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1318                    << ToString(status);
1319       }
1320 
1321       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1322       if (status != miopenStatusSuccess) {
1323         VLOG(2) << "call to miopenCompileFusionPlan (CBA) failed: "
1324                 << ToString(status);
1325 
1326         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1327       } else {
1328         VLOG(2) << "Fusion Plan compile succeeded (CBA)";
1329         fusion_plan_compiled_ = true;
1330       }
1331     } else {
1332       // fusion plan was already compiled...check whether it failed to compile
1333       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1334     }
1335   }
1336 
1337   miopenStatus_t SetConvolutionArgs(const void* filter_data) {
1338     float alpha = 1.0;
1339     float beta = 0.0;
1340     return ScopedFusionPlanBase::SetConvolutionArgs(k_conv_op_idx, &alpha,
1341                                                     &beta, filter_data);
1342   }
1343 
1344   miopenStatus_t SetBiasArgs(const void* bias_data) {
1345     float alpha = 1.0;
1346     float beta = 0.0;
1347     return ScopedFusionPlanBase::SetBiasArgs(k_bias_op_idx, &alpha, &beta,
1348                                              bias_data);
1349   }
1350 
1351   miopenStatus_t SetActivationForwardArgs(
1352       ScopedActivationDescriptor& activation_descriptor) {
1353     float alpha = 1.0;
1354     float beta = 0.0;
1355 
1356     return ScopedFusionPlanBase::SetActivationForwardArgs(
1357         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1358         activation_descriptor.beta_, activation_descriptor.gamma_);
1359   }
1360 
1361   uint64 GetFusionOpHashValue(
1362       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1363       miopenTensorDescriptor_t filter_descriptor,
1364       miopenConvolutionDescriptor_t conv_descriptor,
1365       miopenTensorDescriptor_t bias_descriptor,
1366       ScopedActivationDescriptor& activation_descriptor) {
1367     uint64 hash_value = tensorflow::Hash64("ConvolutionBiasActivation");
1368 
1369     hash_value = tensorflow::Hash64Combine(
1370         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1371 
1372     hash_value =
1373         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1374     hash_value =
1375         tensorflow::Hash64Combine(hash_value, GetHashValue(filter_descriptor));
1376     hash_value =
1377         tensorflow::Hash64Combine(hash_value, GetHashValue(conv_descriptor));
1378     hash_value =
1379         tensorflow::Hash64Combine(hash_value, GetHashValue(bias_descriptor));
1380     hash_value = tensorflow::Hash64Combine(
1381         hash_value, activation_descriptor.GetHashValue());
1382     return hash_value;
1383   }
1384 
1385  private:
1386   const int k_conv_op_idx = 0;
1387   const int k_bias_op_idx = 1;
1388   const int k_actv_op_idx = 2;
1389 
1390   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanConvolutionBiasActivation);
1391 };
1392 
1393 // class to represent the BatchNorm+Activation (inference) fusion plan
1394 class ScopedFusionPlanBatchNormActivationInference
1395     : public ScopedFusionPlanBase {
1396  public:
1397   ScopedFusionPlanBatchNormActivationInference(
1398       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1399       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1400       ScopedActivationDescriptor& activation_descriptor)
1401       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1402                              input_descriptor) {
1403     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1404                                        scale_offset_mean_variance_descriptor,
1405                                        activation_descriptor);
1406 
1407     bool is_compiled = CachedFusionPlans::FindOrCreate(
1408         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1409 
1410     if (!is_compiled) {
1411       miopenFusionOpDescriptor_t batchnorm_op;
1412       auto status = wrap::miopenCreateOpBatchNormInference(
1413           fusion_plan_, &batchnorm_op, miopenBNSpatial,
1414           scale_offset_mean_variance_descriptor);
1415 
1416       if (status != miopenStatusSuccess) {
1417         LOG(FATAL) << "call to miopenCreateOpBatchNormInference failed: "
1418                    << ToString(status);
1419       }
1420 
1421       miopenFusionOpDescriptor_t actv_op;
1422       status = wrap::miopenCreateOpActivationForward(
1423           fusion_plan_, &actv_op,
1424           activation_descriptor.miopen_activation_mode_);
1425       if (status != miopenStatusSuccess) {
1426         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1427                    << ToString(status);
1428       }
1429 
1430       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1431       if (status != miopenStatusSuccess) {
1432         VLOG(2) << "call to miopenCompileFusionPlan (BnA inference) failed: "
1433                 << ToString(status);
1434 
1435         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1436       } else {
1437         VLOG(2) << "Fusion Plan compile succeeded (BnA inference)";
1438         fusion_plan_compiled_ = true;
1439       }
1440     } else {
1441       // fusion plan was already compiled...check whether it failed to compile
1442       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1443     }
1444   }
1445 
1446   miopenStatus_t SetBatchNormInferenceArgs(const void* scale,
1447                                            const void* offset, const void* mean,
1448                                            const void* variance,
1449                                            double epsilon) {
1450     float alpha = 1.0;
1451     float beta = 0.0;
1452     return ScopedFusionPlanBase::SetBatchNormInferenceArgs(
1453         k_batchnorm_op_idx, &alpha, &beta, scale, offset, mean, variance,
1454         epsilon);
1455   }
1456 
1457   miopenStatus_t SetActivationForwardArgs(
1458       ScopedActivationDescriptor& activation_descriptor) {
1459     float alpha = 1.0;
1460     float beta = 0.0;
1461 
1462     return ScopedFusionPlanBase::SetActivationForwardArgs(
1463         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1464         activation_descriptor.beta_, activation_descriptor.gamma_);
1465   }
1466 
1467   uint64 GetFusionOpHashValue(
1468       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1469       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1470       ScopedActivationDescriptor& activation_descriptor) {
1471     uint64 hash_value = tensorflow::Hash64("BatchNormActivationInference");
1472 
1473     hash_value = tensorflow::Hash64Combine(
1474         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1475 
1476     hash_value =
1477         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1478 
1479     hash_value = tensorflow::Hash64Combine(
1480         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1481 
1482     hash_value = tensorflow::Hash64Combine(
1483         hash_value, activation_descriptor.GetHashValue());
1484     return hash_value;
1485   }
1486 
1487  private:
1488   const int k_batchnorm_op_idx = 0;
1489   const int k_actv_op_idx = 1;
1490 
1491   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationInference);
1492 };
1493 
1494 // class to represent the BatchNorm+Activation (training-forward) fusion plan
1495 class ScopedFusionPlanBatchNormActivationForward : public ScopedFusionPlanBase {
1496  public:
1497   ScopedFusionPlanBatchNormActivationForward(
1498       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1499       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1500       ScopedActivationDescriptor& activation_descriptor)
1501       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1502                              input_descriptor) {
1503     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1504                                        scale_offset_mean_variance_descriptor,
1505                                        activation_descriptor);
1506 
1507     bool is_compiled = CachedFusionPlans::FindOrCreate(
1508         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1509 
1510     if (!is_compiled) {
1511       miopenFusionOpDescriptor_t batchnorm_op;
1512       auto status = wrap::miopenCreateOpBatchNormForward(
1513           fusion_plan_, &batchnorm_op, miopenBNSpatial,
1514           true /* runningMeanVariance */);
1515 
1516       if (status != miopenStatusSuccess) {
1517         LOG(FATAL) << "call to miopenCreateOpBatchNormForward failed: "
1518                    << ToString(status);
1519       }
1520 
1521       miopenFusionOpDescriptor_t actv_op;
1522       status = wrap::miopenCreateOpActivationForward(
1523           fusion_plan_, &actv_op,
1524           activation_descriptor.miopen_activation_mode_);
1525       if (status != miopenStatusSuccess) {
1526         LOG(FATAL) << "call to miopenCreateOpActivationForward failed: "
1527                    << ToString(status);
1528       }
1529 
1530       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1531       if (status != miopenStatusSuccess) {
1532         VLOG(2) << "call to miopenCompileFusionPlan (BnA forward) failed: "
1533                 << ToString(status);
1534 
1535         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1536       } else {
1537         VLOG(2) << "Fusion Plan compile succeeded (BnA forward)";
1538         fusion_plan_compiled_ = true;
1539       }
1540     } else {
1541       // fusion plan was already compiled...check whether it failed to compile
1542       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1543     }
1544   }
1545 
1546   miopenStatus_t SetBatchNormForwardArgs(const void* scale, const void* offset,
1547                                          void* batch_mean, void* batch_var,
1548                                          void* saved_mean, void* saved_var,
1549                                          double epsilon) {
1550     float alpha = 1.0;
1551     float beta = 0.0;
1552     return ScopedFusionPlanBase::SetBatchNormForwardArgs(
1553         k_batchnorm_op_idx, &alpha, &beta, scale, offset, batch_mean, batch_var,
1554         saved_mean, saved_var, epsilon, /*exponential_average_factor=*/1.0);
1555   }
1556 
1557   miopenStatus_t SetActivationForwardArgs(
1558       ScopedActivationDescriptor& activation_descriptor) {
1559     float alpha = 1.0;
1560     float beta = 0.0;
1561 
1562     return ScopedFusionPlanBase::SetActivationForwardArgs(
1563         k_actv_op_idx, &alpha, &beta, activation_descriptor.alpha_,
1564         activation_descriptor.beta_, activation_descriptor.gamma_);
1565   }
1566 
1567   uint64 GetFusionOpHashValue(
1568       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1569       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1570       ScopedActivationDescriptor& activation_descriptor) {
1571     uint64 hash_value = tensorflow::Hash64("BatchNormActivationForward");
1572 
1573     hash_value = tensorflow::Hash64Combine(
1574         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1575 
1576     hash_value =
1577         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1578 
1579     hash_value = tensorflow::Hash64Combine(
1580         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1581 
1582     hash_value = tensorflow::Hash64Combine(
1583         hash_value, activation_descriptor.GetHashValue());
1584     return hash_value;
1585   }
1586 
1587  private:
1588   const int k_batchnorm_op_idx = 0;
1589   const int k_actv_op_idx = 1;
1590 
1591   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationForward);
1592 };
1593 
1594 // class to represent the BatchNorm+Activation (training-backward) fusion plan
1595 class ScopedFusionPlanBatchNormActivationBackward
1596     : public ScopedFusionPlanBase {
1597  public:
1598   ScopedFusionPlanBatchNormActivationBackward(
1599       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1600       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1601       ScopedActivationDescriptor& activation_descriptor)
1602       : ScopedFusionPlanBase(miopen_handle, miopenVerticalFusion,
1603                              input_descriptor) {
1604     uint64 hash = GetFusionOpHashValue(miopen_handle, input_descriptor,
1605                                        scale_offset_mean_variance_descriptor,
1606                                        activation_descriptor);
1607 
1608     bool is_compiled = CachedFusionPlans::FindOrCreate(
1609         hash, &fusion_plan_, miopenVerticalFusion, input_descriptor);
1610 
1611     if (!is_compiled) {
1612       miopenFusionOpDescriptor_t batchnorm_op;
1613       auto status = wrap::miopenCreateOpBatchNormBackward(
1614           fusion_plan_, &batchnorm_op, miopenBNSpatial);
1615 
1616       if (status != miopenStatusSuccess) {
1617         LOG(FATAL) << "call to miopenCreateOpBatchNormBackward failed: "
1618                    << ToString(status);
1619       }
1620 
1621       miopenFusionOpDescriptor_t actv_op;
1622       status = wrap::miopenCreateOpActivationBackward(
1623           fusion_plan_, &actv_op,
1624           activation_descriptor.miopen_activation_mode_);
1625       if (status != miopenStatusSuccess) {
1626         LOG(FATAL) << "call to miopenCreateOpActivationBackward failed: "
1627                    << ToString(status);
1628       }
1629 
1630       status = wrap::miopenCompileFusionPlan(miopen_handle_, fusion_plan_);
1631       if (status != miopenStatusSuccess) {
1632         VLOG(2) << "call to miopenCompileFusionPlan (BnA backward) failed: "
1633                 << ToString(status);
1634 
1635         CachedFusionPlans::MarkFusionPlanUnsupported(hash);
1636       } else {
1637         VLOG(2) << "Fusion Plan compile succeeded (BnA backward)";
1638         fusion_plan_compiled_ = true;
1639       }
1640     } else {
1641       // fusion plan was already compiled...check whether it failed to compile
1642       fusion_plan_compiled_ = !CachedFusionPlans::IsUnsupportedFusionPlan(hash);
1643     }
1644   }
1645 
1646   miopenStatus_t SetBatchNormBackwardArgs(const void* x, const void* scale,
1647                                           const void* offset,
1648                                           const void* saved_mean,
1649                                           const void* saved_var,
1650                                           void* scale_grad, void* offset_grad) {
1651     float alpha = 1.0;
1652     float beta = 0.0;
1653 
1654     return ScopedFusionPlanBase::SetBatchNormBackwardArgs(
1655         k_batchnorm_op_idx, &alpha, &beta, x, scale, offset, scale_grad,
1656         offset_grad, saved_mean, saved_var);
1657   }
1658 
1659   miopenStatus_t SetActivationBackwardArgs(
1660       ScopedActivationDescriptor& activation_descriptor, const void* y) {
1661     float alpha = 1.0;
1662     float beta = 0.0;
1663 
1664     return ScopedFusionPlanBase::SetActivationBackwardArgs(
1665         k_actv_op_idx, &alpha, &beta, y, activation_descriptor.alpha_,
1666         activation_descriptor.beta_, activation_descriptor.gamma_);
1667   }
1668 
1669   uint64 GetFusionOpHashValue(
1670       miopenHandle_t miopen_handle, miopenTensorDescriptor_t input_descriptor,
1671       miopenTensorDescriptor_t scale_offset_mean_variance_descriptor,
1672       ScopedActivationDescriptor& activation_descriptor) {
1673     uint64 hash_value = tensorflow::Hash64("BatchNormActivationBackward");
1674 
1675     hash_value = tensorflow::Hash64Combine(
1676         hash_value, tensorflow::hash<miopenHandle_t>()(miopen_handle));
1677 
1678     hash_value =
1679         tensorflow::Hash64Combine(hash_value, GetHashValue(input_descriptor));
1680 
1681     hash_value = tensorflow::Hash64Combine(
1682         hash_value, GetHashValue(scale_offset_mean_variance_descriptor));
1683 
1684     hash_value = tensorflow::Hash64Combine(
1685         hash_value, activation_descriptor.GetHashValue());
1686     return hash_value;
1687   }
1688 
1689  private:
1690   const int k_batchnorm_op_idx = 0;
1691   const int k_actv_op_idx = 1;
1692 
1693   SE_DISALLOW_COPY_AND_ASSIGN(ScopedFusionPlanBatchNormActivationBackward);
1694 };
1695 
1696 namespace {
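// Helpers that convert StreamExecutor dnn:: enums to their MIOpen
// counterparts. Note that ToMIOpenDataType accepts a data_layout argument
// (presumably for parity with other backends) but does not currently use it.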
1697 miopenDataType_t ToMIOpenDataType(
1698     dnn::DataType data_type,
1699     dnn::DataLayout data_layout = dnn::DataLayout::kBatchDepthYX) {
1700   switch (data_type) {
1701     case dnn::DataType::kFloat:
1702       return miopenFloat;
1703     case dnn::DataType::kHalf:
1704       return miopenHalf;
1705     case dnn::DataType::kDouble:
1706     default:
1707       LOG(FATAL) << "Invalid DNN data type: " << static_cast<int>(data_type);
1708   }
1709 }
1710 
1711 miopenRNNInputMode_t ToMIOpenRnnInputMode(dnn::RnnInputMode input_mode) {
1712   switch (input_mode) {
1713     case dnn::RnnInputMode::kRnnLinearSkip:
1714       return miopenRNNlinear;
1715     case dnn::RnnInputMode::kRnnSkipInput:
1716       return miopenRNNskip;
1717     default:
1718       LOG(FATAL) << "Invalid RNN input mode: " << static_cast<int>(input_mode);
1719   }
1720 }
1721 
1722 miopenRNNDirectionMode_t ToMIOpenRnnDirectionMode(
1723     dnn::RnnDirectionMode direction_mode) {
1724   switch (direction_mode) {
1725     case dnn::RnnDirectionMode::kRnnUnidirectional:
1726       return miopenRNNunidirection;
1727     case dnn::RnnDirectionMode::kRnnBidirectional:
1728       return miopenRNNbidirection;
1729     default:
1730       LOG(FATAL) << "Invalid RNN direction mode: "
1731                  << static_cast<int>(direction_mode);
1732   }
1733 }
1734 
1735 miopenRNNMode_t ToMIOpenRnnMode(dnn::RnnMode rnn_mode) {
1736   switch (rnn_mode) {
1737     case dnn::RnnMode::kRnnRelu:
1738       return miopenRNNRELU;
1739     case dnn::RnnMode::kRnnTanh:
1740       return miopenRNNTANH;
1741     case dnn::RnnMode::kRnnLstm:
1742       return miopenLSTM;
1743     case dnn::RnnMode::kRnnGru:
1744       return miopenGRU;
1745     default:
1746       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1747   }
1748 }
1749 
1750 template <typename Base>
1751 class MixinBase : public Base {};
1752 template <>
1753 class MixinBase<void> {};
1754 
1755 }  // namespace
1756 
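// Logs the error, records it on the enclosing MIOpenDescriptorCommon-derived
// object via SetFailure(), and returns early when a MIOpen call fails.
// Intended for use inside constructors and destructors that return void.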
1757 #define RETURN_IF_MIOPEN_ERROR(STATUS, ...)                              \
1758   if (!SE_PREDICT_TRUE((STATUS) == miopenStatusSuccess)) {               \
1759     string error_msg = absl::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
1760     SetFailure(port::Status(port::error::UNKNOWN, error_msg));           \
1761     LOG(ERROR) << error_msg;                                             \
1762     return;                                                              \
1763   }
1764 
1765 template <typename Base>
1766 class MIOpenDescriptorCommon : public MixinBase<Base> {
1767  public:
1768   bool ok() const { return status_.ok(); }
1769   port::Status Status() const { return status_; }
1770 
1771  protected:
1772   void SetFailure(const port::Status& status) { status_.Update(status); }
1773   port::Status status_;
1774 };
1775 
1776 class MIOpenRnnParamsDescriptor : public MIOpenDescriptorCommon<void> {
1777  public:
1778   typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
1779   typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
1780   MIOpenRnnParamsDescriptor(miopenHandle_t miopen_handle,
1781                             const MIOpenRnnDescriptor& rnn_desc);
1782   ~MIOpenRnnParamsDescriptor() {
1783     auto status = wrap::miopenDestroyTensorDescriptor(handle_);
1784     RETURN_IF_MIOPEN_ERROR(status, "Failed to destroy RNN tensor descriptor");
1785   }
1786   miopenTensorDescriptor_t handle() const {
1787     if (!ok()) return nullptr;
1788     return handle_;
1789   }
1790   int64 params_size_in_bytes() const { return params_size_in_bytes_; }
1791   ParamsRegions params_weights() const {
1792     if (!ok()) return ParamsRegions();
1793     return weights_;
1794   }
1795   ParamsRegions params_biases() const {
1796     if (!ok()) return ParamsRegions();
1797     return biases_;
1798   }
1799 
1800  private:
1801   int GetRegionCountPerLayer() const;
1802   miopenTensorDescriptor_t handle_;
1803   const MIOpenRnnDescriptor* rnn_desc_;
1804   int64 params_size_in_bytes_;
1805   ParamsRegions weights_;
1806   ParamsRegions biases_;
1807   port::Status status_;
1808   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnParamsDescriptor);
1809 };
1810 
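// Wraps a miopenRNNDescriptor_t together with the hyperparameters used to
// create it, and owns the corresponding params (weights/biases) descriptor.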
1811 class MIOpenRnnDescriptor : public MIOpenDescriptorCommon<dnn::RnnDescriptor> {
1812  public:
1813   MIOpenRnnDescriptor(miopenHandle_t miopen_handle, int num_layers,
1814                       int hidden_size, int input_size,
1815                       miopenRNNInputMode_t input_mode,
1816                       miopenRNNDirectionMode_t direction_mode,
1817                       miopenRNNMode_t rnn_mode, miopenDataType_t data_type,
1818                       float dropout, uint64 seed,
1819                       ScratchAllocator* state_allocator)
1820       : rnn_desc_(nullptr),
1821         num_layers_(num_layers),
1822         hidden_size_(hidden_size),
1823         input_size_(input_size),
1824         input_mode_(input_mode),
1825         direction_mode_(direction_mode),
1826         rnn_mode_(rnn_mode),
1827         data_type_(data_type) {
1828     // Create the RNN handle
1829     auto status = wrap::miopenCreateRNNDescriptor(&rnn_desc_);
1830     RETURN_IF_MIOPEN_ERROR(status, "Unable to create RNN descriptor");
1831     status = wrap::miopenSetRNNDescriptor(
1832         rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
1833         num_layers /*numLayers*/, input_mode /*inputMode*/,
1834         direction_mode /*direction*/, rnn_mode /*mode*/,
1835         miopenRNNwithBias /*biasMode*/, miopenRNNdefault /*algo*/,
1836         data_type /*dataType*/);
1837     RETURN_IF_MIOPEN_ERROR(status, "Unable to update RNN descriptor");
1838     // Create the params handle.
1839     miopen_params_desc_.reset(
1840         new MIOpenRnnParamsDescriptor(miopen_handle, *this));
1841     if (!miopen_params_desc_->ok()) {
1842       SetFailure(miopen_params_desc_->Status());
1843       return;
1844     }
1845   }
1846   ~MIOpenRnnDescriptor() override {
1847     if (rnn_desc_) {
1848       auto status = wrap::miopenDestroyRNNDescriptor(rnn_desc_);
1849       RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN descriptor");
1850     }
1851   }
1852   miopenRNNDescriptor_t handle() const {
1853     if (!ok()) return nullptr;
1854     return rnn_desc_;
1855   }
1856   int num_layers() const { return num_layers_; }
1857   int hidden_size() const { return hidden_size_; }
1858   int input_size() const { return input_size_; }
1859   miopenRNNInputMode_t input_mode() const { return input_mode_; }
1860   miopenRNNDirectionMode_t direction_mode() const { return direction_mode_; }
1861   miopenRNNMode_t rnn_mode() const { return rnn_mode_; }
1862   miopenDataType_t data_type() const { return data_type_; }
1863   int64 ParamsSizeInBytes() const override {
1864     return miopen_params_desc_->params_size_in_bytes();
1865   }
1866   miopenTensorDescriptor_t params_handle() const {
1867     if (!miopen_params_desc_) return nullptr;
1868     return miopen_params_desc_->handle();
1869   }
1870   ParamsRegions ParamsWeightRegions() const override {
1871     if (!ok()) return ParamsRegions();
1872     return miopen_params_desc_->params_weights();
1873   }
1874   ParamsRegions ParamsBiasRegions() const override {
1875     if (!ok()) return ParamsRegions();
1876     return miopen_params_desc_->params_biases();
1877   }
1878 
1879  private:
1880   miopenRNNDescriptor_t rnn_desc_;
1881   int num_layers_;
1882   int hidden_size_;
1883   int input_size_;
1884   miopenRNNInputMode_t input_mode_;
1885   miopenRNNDirectionMode_t direction_mode_;
1886   miopenRNNMode_t rnn_mode_;
1887   miopenDataType_t data_type_;
1888   port::Status status_;
1889   // no dropout in MIOpen.
1890   // std::unique_ptr<miopenDropoutDescriptor> miopen_dropout_desc_;
1891   std::unique_ptr<MIOpenRnnParamsDescriptor> miopen_params_desc_;
1892   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnDescriptor);
1893 };
1894 
1895 // Returns the number of weight/bias parameter regions per layer for the
1896 // RNN mode in use (2 for vanilla RNN, 8 for LSTM, 6 for GRU).
1897 int MIOpenRnnParamsDescriptor::GetRegionCountPerLayer() const {
1898   auto rnn_mode = rnn_desc_->rnn_mode();
1899   switch (rnn_mode) {
1900     case miopenRNNRELU:
1901     case miopenRNNTANH:
1902       return 2;
1903     case miopenLSTM:
1904       return 8;
1905     case miopenGRU:
1906       return 6;
1907     default:
1908       LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
1909   }
1910 }
1911 
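// Describes the per-timestep input/output tensors of an RNN: a single 2-D
// [batch_size, data_size] descriptor replicated seq_length times.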
1912 class MIOpenRnnSequenceTensorDescriptor
1913     : public MIOpenDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
1914  public:
1915   MIOpenRnnSequenceTensorDescriptor(int seq_length, int batch_size,
1916                                     int data_size, miopenDataType_t data_type)
1917       : seq_length_(seq_length),
1918         batch_size_(batch_size),
1919         data_size_(data_size),
1920         data_type_(data_type) {
1921     miopenTensorDescriptor_t handle = nullptr;
1922     if (seq_length <= 0) {
1923       string error_msg =
1924           absl::StrCat("sequence length must be positive: ", seq_length);
1925       LOG(ERROR) << error_msg;
1926       SetFailure(port::Status(port::error::UNKNOWN, error_msg));
1927       return;
1928     }
1929     auto status = wrap::miopenCreateTensorDescriptor(&handle);
1930     RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor");
1931     std::array<int, 2> dims = {{batch_size, data_size}};
1932     status = wrap::miopenSetTensorDescriptor(
1933         handle /*tensorDesc*/, data_type /*dataType*/, 2 /*nbDims*/,
1934         dims.data() /*dimA*/, nullptr /*strideA*/);
1935     RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor");
1936     // Replicate handle across the number of steps.
1937     handles_.assign(seq_length, handle);
1938   }
1939 
1940   ~MIOpenRnnSequenceTensorDescriptor() override {
1941     // Only the first one needs to be destroyed. All others are the same.
1942     auto status = wrap::miopenDestroyTensorDescriptor(handles_[0]);
1943     RETURN_IF_MIOPEN_ERROR(status,
1944                            "Failed to destroy sequence tensor descriptor");
1945   }
1946 
1947   const miopenTensorDescriptor_t* handles() const {
1948     if (!ok()) return nullptr;
1949     CHECK(!handles_.empty()) << "handles cannot be empty";
1950     return handles_.data();
1951   }
1952 
1953   int seq_length() const { return seq_length_; }
1954   int batch_size() const { return batch_size_; }
1955   int data_size() const { return data_size_; }
1956 
1957  private:
1958   int seq_length_;
1959   int batch_size_;
1960   int data_size_;
1961   miopenDataType_t data_type_;
1962   std::vector<miopenTensorDescriptor_t> handles_;
1963   port::Status status_;
1964   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnSequenceTensorDescriptor);
1965 };
1966 
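// Describes an RNN hidden/cell state tensor of shape
// [num_layers, batch_size, data_size].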
1967 class MIOpenRnnStateTensorDescriptor
1968     : public MIOpenDescriptorCommon<dnn::RnnStateTensorDescriptor> {
1969  public:
1970   MIOpenRnnStateTensorDescriptor(int num_layers, int batch_size, int data_size,
1971                                  miopenDataType_t data_type)
1972       : handle_(nullptr),
1973         num_layers_(num_layers),
1974         batch_size_(batch_size),
1975         data_size_(data_size),
1976         data_type_(data_type) {
1977     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
1978     RETURN_IF_MIOPEN_ERROR(status, "Failed to create tensor descriptor");
1979     std::array<int, 3> dims = {{num_layers, batch_size, data_size}};
1980     status = wrap::miopenSetTensorDescriptor(
1981         handle_ /*tensorDesc*/, data_type /*dataType*/, 3 /*nbDims*/,
1982         dims.data() /*dimA*/, nullptr /*strideA*/);
1983     RETURN_IF_MIOPEN_ERROR(status, "Failed to update tensor descriptor");
1984   }
1985 
1986   ~MIOpenRnnStateTensorDescriptor() override {
1987     if (handle_) {
1988       auto status = wrap::miopenDestroyTensorDescriptor(handle_);
1989       RETURN_IF_MIOPEN_ERROR(status, "Unable to destroy RNN state tensor");
1990     }
1991   }
1992 
1993   miopenTensorDescriptor_t handle() const {
1994     if (!ok()) return nullptr;
1995     return handle_;
1996   }
1997   int num_layers() const { return num_layers_; }
1998   int batch_size() const { return batch_size_; }
1999   int data_size() const { return data_size_; }
2000 
2001  private:
2002   miopenTensorDescriptor_t handle_;
2003   int num_layers_;
2004   int batch_size_;
2005   int data_size_;
2006   port::Status status_;
2007   miopenDataType_t data_type_;
2008   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenRnnStateTensorDescriptor);
2009 };
2010 
2011 namespace {
2012 
2013 struct RnnModelDims {
2014   int num_layers = 0;
2015   int batch_size = 0;
2016   int seq_length = 0;
2017   int hidden_size = 0;
2018   int input_size = 0;
2019   int dir_count = 0;
2020 };
2021 
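// Extracts the RNN model dimensions from the descriptors and checks the
// expected shape relationships:
//   input_h  : [num_layers * dir_count, batch_size, hidden_size]
//   input_c  : same shape as input_h
//   output   : [seq_length, batch_size, hidden_size * dir_count]
//   output_h : same shape as input_h
//   output_c : same shape as input_h
// Logs and returns false if any check fails.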
2022 template <class T>
2023 bool ExtractAndCheckRnnForward(
2024     const MIOpenRnnDescriptor& rnn_desc,
2025     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2026     const DeviceMemory<T>& input_data,
2027     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2028     const DeviceMemory<T>& input_h_data,
2029     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2030     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2031     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2032     const DeviceMemory<T>& output_data,
2033     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2034     const DeviceMemory<T>& output_h_data,
2035     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2036     const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) {
2037   // extract model parameters
2038   model_dims->num_layers = rnn_desc.num_layers();
2039   model_dims->batch_size = input_desc.batch_size();
2040   model_dims->seq_length = input_desc.seq_length();
2041   model_dims->hidden_size = rnn_desc.hidden_size();
2042   model_dims->input_size = input_desc.data_size();
2043   model_dims->dir_count =
2044       (rnn_desc.direction_mode() == miopenRNNbidirection) ? 2 : 1;
2045 
2046   // check parameters
2047   if (!(input_h_desc.num_layers() ==
2048             model_dims->num_layers * model_dims->dir_count &&
2049         input_h_desc.batch_size() == model_dims->batch_size &&
2050         input_h_desc.data_size() == model_dims->hidden_size)) {
2051     LOG(ERROR) << "Invalid input_h shape";
2052     return false;
2053   }
2054   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
2055         input_h_desc.batch_size() == input_c_desc.batch_size() &&
2056         input_h_desc.data_size() == input_c_desc.data_size())) {
2057     LOG(ERROR) << "Invalid input_c shape";
2058     return false;
2059   }
2060   if (!(output_desc.seq_length() == model_dims->seq_length &&
2061         output_desc.batch_size() == model_dims->batch_size &&
2062         output_desc.data_size() ==
2063             model_dims->hidden_size * model_dims->dir_count)) {
2064     LOG(ERROR) << "Invalid output shape";
2065     return false;
2066   }
2067   if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
2068         input_h_desc.batch_size() == output_h_desc.batch_size() &&
2069         input_h_desc.data_size() == output_h_desc.data_size())) {
2070     LOG(ERROR) << "Invalid output_h shape";
2071     return false;
2072   }
2073   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
2074         input_h_desc.batch_size() == output_c_desc.batch_size() &&
2075         input_h_desc.data_size() == output_c_desc.data_size())) {
2076     LOG(ERROR) << "Invalid output_c shape";
2077     return false;
2078   }
2079 
2080   return true;
2081 }
2082 
2083 bool CheckRNNParameterSize(
2084     miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc,
2085     const MIOpenRnnSequenceTensorDescriptor& input_desc) {
2086   size_t params_size_in_bytes = 0;
2087   auto status = wrap::miopenGetRNNParamsSize(
2088       miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2089       input_desc.handles()[0] /*xDesc*/, &params_size_in_bytes /*sizeInBytes*/,
2090       rnn_desc.data_type() /*dataType*/);
2091   if (status != miopenStatusSuccess) {
2092     LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
2093     return false;
2094   }
2095   return static_cast<int64>(params_size_in_bytes) ==
2096          rnn_desc.ParamsSizeInBytes();
2097 }
2098 
2099 bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle,
2100                         const MIOpenRnnDescriptor& rnn_desc,
2101                         const MIOpenRnnSequenceTensorDescriptor& input_desc,
2102                         ScratchAllocator* workspace_allocator,
2103                         DeviceMemory<uint8>* workspace) {
2104   // Query the workspace size.
2105   size_t workspace_size_in_bytes = 0;
2106   auto status = wrap::miopenGetRNNWorkspaceSize(
2107       miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2108       input_desc.seq_length() /*seqLength*/, input_desc.handles() /*xDesc*/,
2109       &workspace_size_in_bytes /*sizeInBytes*/);
2110   if (status != miopenStatusSuccess) {
2111     LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
2112     return false;
2113   }
2114   // Allocate the workspace.
2115   if (workspace_size_in_bytes > 0) {
2116     auto allocated =
2117         workspace_allocator->AllocateBytes(workspace_size_in_bytes);
2118     if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
2119       LOG(ERROR) << "Failed to allocate RNN workspace";
2120 
2121       return false;
2122     }
2123     stream->ThenMemZero(workspace, workspace_size_in_bytes);
2124   } else {
2125     *workspace = DeviceMemory<uint8>();
2126   }
2127   return true;
2128 }
2129 
2130 }  // namespace
2131 
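// Shared implementation for the RNN forward pass: validates shapes and the
// params size, allocates (and zeroes) the workspace and, when training, the
// reserve space, then calls miopenRNNForwardInference or
// miopenRNNForwardTraining.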
2132 template <class T>
2133 bool MIOpenSupport::DoRnnForwardImpl(
2134     Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
2135     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2136     const DeviceMemory<T>& input_data,
2137     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2138     const DeviceMemory<T>& input_h_data,
2139     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2140     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2141     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2142     DeviceMemory<T>* output_data,
2143     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2144     DeviceMemory<T>* output_h_data,
2145     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2146     DeviceMemory<T>* output_c_data, bool is_training,
2147     ScratchAllocator* reserve_space_allocator,
2148     ScratchAllocator* workspace_allocator) {
2149   // extract model parameters
2150   RnnModelDims model_dims;
2151   bool res = ExtractAndCheckRnnForward(
2152       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
2153       input_c_desc, input_c_data, params, output_desc, *output_data,
2154       output_h_desc, *output_h_data, output_c_desc, *output_c_data,
2155       &model_dims);
2156   if (!res) {
2157     LOG(ERROR) << "Invalid parameters for RNN Model";
2158     return false;
2159   }
2160 
2161   auto miopen = miopen_->GetHandle(parent_, stream);
2162 
2163   // check params size
2164 
2165   if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) {
2166     LOG(ERROR) << "Invalid parameters";
2167     return false;
2168   }
2169 
2170   // create the workspace
2171   DeviceMemory<uint8> workspace;
2172   if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc,
2173                           workspace_allocator, &workspace)) {
2174     LOG(ERROR) << "Unable to create rnn workspace";
2175 
2176     return false;
2177   }
2178 
2179   // query the reserve space size
2180   // allocate the reserve space
2181   DeviceMemory<uint8> reserve_space;
2182   if (is_training) {
2183     size_t reserve_space_size_in_bytes = 0;
2184     auto status = wrap::miopenGetRNNTrainingReserveSize(
2185         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2186         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2187         &reserve_space_size_in_bytes /*sizeInBytes*/);
2188     if (status != miopenStatusSuccess) {
2189       LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
2190       return false;
2191     }
2192 
2193     if (reserve_space_size_in_bytes > 0) {
2194       auto allocated =
2195           reserve_space_allocator->AllocateBytes(reserve_space_size_in_bytes);
2196       if (!allocated.ok() ||
2197           (reserve_space = allocated.ValueOrDie()) == nullptr) {
2198         LOG(ERROR) << "Failed to allocate RNN reserve space";
2199         return false;
2200       }
2201       stream->ThenMemZero(&reserve_space, reserve_space_size_in_bytes);
2202     }
2203   }
2204 
2205   // make the forward call
2206   if (!is_training) {
2207     auto status = wrap::miopenRNNForwardInference(
2208         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2209         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2210         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2211         input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
2212         input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
2213         params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
2214         output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
2215         output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
2216         output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
2217         workspace.size() /*workSpaceSizeInBytes*/);
2218 
2219     if (status != miopenStatusSuccess) {
2220       LOG(ERROR) << "Failed to call miopenRNNForwardInference: "
2221                  << ToString(status);
2222       return false;
2223     }
2224   } else {
2225     auto status = wrap::miopenRNNForwardTraining(
2226         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2227         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2228         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2229         input_h_data.opaque() /*hx*/, input_c_desc.handle() /*cxDesc*/,
2230         input_c_data.opaque() /*cx*/, rnn_desc.params_handle() /*wDesc*/,
2231         params.opaque() /*w*/, output_desc.handles() /*yDesc*/,
2232         output_data->opaque() /*y*/, output_h_desc.handle() /*hyDesc*/,
2233         output_h_data->opaque() /*hy*/, output_c_desc.handle() /*cyDesc*/,
2234         output_c_data->opaque() /*cy*/, workspace.opaque() /*workspace*/,
2235         workspace.size() /*workSpaceSizeInBytes*/,
2236         reserve_space.opaque() /*reserveSpace*/,
2237         reserve_space.size() /*reserveSpaceSizeInBytes*/);
2238     if (status != miopenStatusSuccess) {
2239       LOG(ERROR) << "Failed to call miopenRNNForwardTraining: "
2240                  << ToString(status);
2241       return false;
2242     }
2243   }
2244   return true;
2245 }
2246 
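// Shared implementation for the RNN backward pass: validates shapes, zeroes
// the input gradient buffers (see the workaround note below), then calls
// miopenRNNBackwardData and, if params_backprop_data is non-null,
// miopenRNNBackwardWeights.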
2247 template <class T>
2248 bool MIOpenSupport::DoRnnBackwardImpl(
2249     Stream* stream, const MIOpenRnnDescriptor& rnn_desc,
2250     const MIOpenRnnSequenceTensorDescriptor& input_desc,
2251     const DeviceMemory<T>& input_data,
2252     const MIOpenRnnStateTensorDescriptor& input_h_desc,
2253     const DeviceMemory<T>& input_h_data,
2254     const MIOpenRnnStateTensorDescriptor& input_c_desc,
2255     const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
2256     const MIOpenRnnSequenceTensorDescriptor& output_desc,
2257     const DeviceMemory<T>& output_data,
2258     const MIOpenRnnStateTensorDescriptor& output_h_desc,
2259     const DeviceMemory<T>& output_h_data,
2260     const MIOpenRnnStateTensorDescriptor& output_c_desc,
2261     const DeviceMemory<T>& output_c_data,
2262     const DeviceMemory<T>& output_backprop_data,
2263     const DeviceMemory<T>& output_h_backprop_data,
2264     const DeviceMemory<T>& output_c_backprop_data,
2265     DeviceMemory<T>* input_backprop_data,
2266     DeviceMemory<T>* input_h_backprop_data,
2267     DeviceMemory<T>* input_c_backprop_data,
2268     DeviceMemory<T>* params_backprop_data,
2269     DeviceMemory<uint8>* reserve_space_data,
2270     ScratchAllocator* workspace_allocator) {
2271   // extract model parameters
2272   RnnModelDims model_dims;
2273   bool res = ExtractAndCheckRnnForward(
2274       rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
2275       input_c_desc, input_c_data, params, output_desc, output_data,
2276       output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims);
2277   if (!res) {
2278     LOG(ERROR) << "Invalid parameters for RNN Model";
2279     return false;
2280   }
2281 
2282   auto miopen = miopen_->GetHandle(parent_, stream);
2283 
2284   // check params size
2285 
2286   if (!CheckRNNParameterSize(miopen.handle(), rnn_desc, input_desc)) {
2287     LOG(ERROR) << "Invalid parameters";
2288     return false;
2289   }
2290 
2291   // create the workspace
2292   DeviceMemory<uint8> workspace;
2293   if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc,
2294                           workspace_allocator, &workspace)) {
2295     LOG(ERROR) << "Unable to create rnn workspace";
2296     return false;
2297   }
2298 
2299   // workaround for missing initialization support in MIOpen.
2300   // TODO: remove this when MIOpen is ready.
2301   auto type_size = std::is_same<T, Eigen::half>::value ? 2 : sizeof(T);
2302   auto size_data = input_desc.seq_length() * input_desc.batch_size() *
2303                    input_desc.data_size();
2304   if ((size_data > 0) && (input_backprop_data->opaque() != nullptr))
2305     stream->ThenMemZero(input_backprop_data, size_data * type_size);
2306 
2307   size_data = input_h_desc.num_layers() * input_h_desc.batch_size() *
2308               input_h_desc.data_size();
2309   if ((size_data > 0) && (input_h_backprop_data->opaque() != nullptr))
2310     stream->ThenMemZero(input_h_backprop_data, size_data * type_size);
2311 
2312   size_data = input_c_desc.num_layers() * input_c_desc.batch_size() *
2313               input_c_desc.data_size();
2314   if ((size_data > 0) && (input_c_backprop_data->opaque() != nullptr))
2315     stream->ThenMemZero(input_c_backprop_data, size_data * type_size);
2316 
2317   // make the backward data call
2318   auto status = wrap::miopenRNNBackwardData(
2319       miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2320       model_dims.seq_length /*seqLength*/, output_desc.handles() /*yDesc*/,
2321       output_data.opaque() /*y*/, output_desc.handles() /*dyDesc*/,
2322       output_backprop_data.opaque() /*dy*/, output_h_desc.handle() /*dhyDesc*/,
2323       output_h_backprop_data.opaque() /*dhy*/,
2324       output_c_desc.handle() /*dcyDesc*/,
2325       output_c_backprop_data.opaque() /*dcy*/,
2326       rnn_desc.params_handle() /*wDesc*/, params.opaque() /*w*/,
2327       input_h_desc.handle() /*hxDesc*/, input_h_data.opaque() /*hx*/,
2328       input_c_desc.handle() /*cxDesc*/, input_c_data.opaque() /*cx*/,
2329       input_desc.handles() /*dxDesc*/, input_backprop_data->opaque() /*dx*/,
2330       input_h_desc.handle() /*dhxDesc*/,
2331       input_h_backprop_data->opaque() /*dhx*/,
2332       input_c_desc.handle() /*dcxDesc*/,
2333       input_c_backprop_data->opaque() /*dcx*/, workspace.opaque() /*workspace*/,
2334       workspace.size() /*workSpaceSizeInBytes*/,
2335       reserve_space_data->opaque() /*reserveSpace*/,
2336       reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
2337   if (status != miopenStatusSuccess) {
2338     LOG(ERROR) << "Failed to call miopenRNNBackwardData: " << ToString(status);
2339     return false;
2340   }
2341 
2342   if (params_backprop_data != nullptr) {
2343     // Clear the dw to zeros.
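    // Zeroing matters because the backward-weights call below is expected to
    // accumulate gradients into dw rather than overwrite it (mirroring the
    // cuDNN semantics); starting from a non-zero buffer would corrupt the
    // result.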
2344     stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
2345     // make the backward weight call
2346     status = wrap::miopenRNNBackwardWeights(
2347         miopen.handle() /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2348         model_dims.seq_length /*seqLength*/, input_desc.handles() /*xDesc*/,
2349         input_data.opaque() /*x*/, input_h_desc.handle() /*hxDesc*/,
2350         input_h_data.opaque() /*hx*/, output_desc.handles() /*yDesc*/,
2351         output_data.opaque() /*y*/, rnn_desc.params_handle() /*dwDesc*/,
2352         params_backprop_data->opaque() /*dw*/, workspace.opaque() /*workspace*/,
2353         workspace.size() /*workSpaceSizeInBytes*/,
2354         reserve_space_data->opaque() /*reserveSpace*/,
2355         reserve_space_data->size() /*reserveSpaceSizeInBytes*/);
2356     if (status != miopenStatusSuccess) {
2357       LOG(ERROR) << "Failed to call miopenRNNBackwardWeights: "
2358                  << ToString(status);
2359       return false;
2360     }
2361   }
2362 
2363   return true;
2364 }
2365 
2366 MIOpenRnnParamsDescriptor::MIOpenRnnParamsDescriptor(
2367     miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc)
2368     : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) {
2369   miopenTensorDescriptor_t input_desc = nullptr;
2370   {
2371     // Query the params size.
2372     auto status = wrap::miopenCreateTensorDescriptor(&input_desc);
2373     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to create tensor descriptor");
2374     std::array<int, 2> dims = {{1, rnn_desc.input_size()}};
2375     status = wrap::miopenSetTensorDescriptor(
2376         input_desc /*tensorDesc*/, rnn_desc.data_type() /*dataType*/,
2377         2 /*nbDims*/, dims.data() /*dimA*/, nullptr /*strideA*/);
2378     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to set tensor descriptor");
2379 
2380     size_t params_size = 0;
2381     status = wrap::miopenGetRNNParamsSize(
2382         miopen_handle /*handle*/, rnn_desc.handle() /*rnnDesc*/,
2383         input_desc /*xDesc*/, &params_size /*sizeInBytes*/,
2384         rnn_desc.data_type() /*dataType*/);
2385     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to get RNN parameter size");
2386     params_size_in_bytes_ = static_cast<int64>(params_size);
2387   }
2388 
2389   {
2390     // Create the params descriptor.
2391     auto status = wrap::miopenCreateTensorDescriptor(&handle_);
2392     RETURN_IF_MIOPEN_ERROR(status,
2393                            "MIOpen fails to create RNN params descriptor");
2394     status = wrap::miopenGetRNNParamsDescriptor(miopen_handle,
2395                                                 rnn_desc.handle(), input_desc,
2396                                                 handle_, rnn_desc.data_type());
2397     RETURN_IF_MIOPEN_ERROR(status,
2398                            "MIOpen fails to update RNN filter descriptor");
2399   }
2400   {
2401     // Release the dummy input tensor descriptor.
2402     auto status = wrap::miopenDestroyTensorDescriptor(input_desc);
2403     RETURN_IF_MIOPEN_ERROR(status, "MIOpen fails to destroy tensor descriptor");
2404   }
2405 }
2406 
2407 class MIOpenCTCLossDescriptor {
2408  public:
2409   explicit MIOpenCTCLossDescriptor(miopenDataType_t data_type) {
2410     auto status = wrap::miopenCreateCTCLossDescriptor(&handle_);
2411     if (status != miopenStatusSuccess) {
2412       LOG(FATAL) << "call to miopenCreateCTCLossDescriptor failed: "
2413                  << ToString(status);
2414     }
2415 
2416     bool apply_softmax_layer = true;
2417     status = wrap::miopenSetCTCLossDescriptor(
2418         handle_, data_type, 0 /*blank_label_id*/, apply_softmax_layer);
2419     if (status != miopenStatusSuccess) {
2420       LOG(FATAL) << "call to miopenSetCTCLossDescriptor failed: "
2421                  << ToString(status);
2422     }
2423   }
2424 
2425   ~MIOpenCTCLossDescriptor() {
2426     auto status = wrap::miopenDestroyCTCLossDescriptor(handle_);
2427     if (status != miopenStatusSuccess) {
2428       LOG(FATAL) << "call to miopenDestroyCTCLossDescriptor failed: "
2429                  << ToString(status);
2430     }
2431   }
2432 
2433   miopenCTCLossDescriptor_t handle() const { return handle_; }
2434 
2435  private:
2436   miopenCTCLossDescriptor_t handle_;  // Owned
2437 
2438   SE_DISALLOW_COPY_AND_ASSIGN(MIOpenCTCLossDescriptor);
2439 };
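// Illustrative use of the wrapper above (a sketch only; the real call sites
// are DoPrepareForCtcLoss and DoCtcLossImpl below, and `handle`, `probs_desc`,
// etc. stand in for the corresponding MIOpen handles):
//
//   MIOpenCTCLossDescriptor ctc_desc(miopenFloat);
//   size_t workspace_bytes = 0;
//   wrap::miopenGetCTCLossWorkspaceSize(
//       handle, probs_desc, grads_desc, labels, label_lengths, input_lengths,
//       MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_desc.handle(),
//       &workspace_bytes);
//   // ...allocate a workspace of workspace_bytes, then call
//   // wrap::miopenCTCLoss(...) with ctc_desc.handle() and that workspace.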
2440 
2441 port::Status MIOpenSupport::DoPrepareForCtcLoss(
2442     Stream* stream, dnn::DataType element_type,
2443     const dnn::RnnStateTensorDescriptor& probs_desc,
2444     const dnn::RnnStateTensorDescriptor& grads_desc,
2445     absl::Span<const int> labels_data,
2446     absl::Span<const int> labels_lengths_data,
2447     absl::Span<const int> input_lengths_data,
2448     ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch_memory,
2449     int* ctc_loss_algo_id) {
2450   auto miopen = miopen_->GetHandle(parent_, stream);
2451 
2452   MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
2453 
2454   // Query the workspace size.
2455   size_t workspace_size_in_bytes = 0;
2456 
2457   const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
2458       static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
2459 
2460   const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
2461       static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
2462 
2463   auto status = wrap::miopenGetCTCLossWorkspaceSize(
2464       miopen.handle(), miopen_probs_desc.handle(), miopen_grads_desc.handle(),
2465       labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
2466       MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, miopen_ctc_loss_desc.handle(),
2467       &workspace_size_in_bytes);
2468 
2469   if (status != miopenStatusSuccess) {
2470     LOG(FATAL) << "call to miopenGetCTCLossWorkspaceSize failed: "
2471                << ToString(status);
2472     return port::InternalError(
2473         "Failed to determine scratch memory size for MIOpen CTC Loss");
2474   }
2475 
2476   *scratch_memory = DeviceMemory<uint8>();
2477 
2478   // Allocate the workspace.
2479   if (workspace_size_in_bytes != 0) {
2480     if (scratch_allocator == nullptr) {
2481       return port::InternalError(
2482           absl::StrCat("An allocator must be specified when scratch memory is "
2483                        "needed"));
2484     }
2485     auto scratch_or = scratch_allocator->AllocateBytes(workspace_size_in_bytes);
2486     if (scratch_or.ok()) {
2487       *scratch_memory = scratch_or.ValueOrDie();
2488     } else {
2489       LOG(ERROR)
2490           << "Failed to allocate scratch memory - "
2491           << scratch_or.status().error_message() << "\n"
2492           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
2493              "larger number (e.g. 8192) to increase the max memory limit.\n"
2494           << "\tIncreasing the max memory limit might help resolve this "
2495              "error";
2496       return port::InternalError(absl::StrCat(
2497           "Failed to allocate scratch memory for MIOpen CTC Loss, of size: ",
2498           workspace_size_in_bytes));
2499     }
2500   }
2501 
2502   return port::Status::OK();
2503 }
2504 
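// The CTC loss path is split into a prepare step and an execute step:
// DoPrepareForCtcLoss (above) queries the deterministic algorithm's workspace
// size and allocates scratch memory through the caller's ScratchAllocator,
// while DoCtcLossImpl (below) issues the actual miopenCTCLoss call using that
// scratch buffer.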
2505 port::Status MIOpenSupport::DoCtcLossImpl(
2506     Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc,
2507     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2508     absl::Span<const int> labels_lengths_data,
2509     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2510     const MIOpenRnnStateTensorDescriptor& grads_desc,
2511     DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc,
2512     DeviceMemory<uint8> scratch_memory, int ctc_loss_algo_id) {
2513   auto miopen = miopen_->GetHandle(parent_, stream);
2514 
2515   int kNumTimestamps = probs_desc.num_layers();
2516   int kBatchSize = probs_desc.batch_size();
2517   int kNumLabels = probs_desc.data_size();
2518   int total_size = kNumLabels * kNumTimestamps * kBatchSize;
2519   (void)total_size;
2520 
2521   auto status = wrap::miopenCTCLoss(
2522       miopen.handle(), probs_desc.handle(), probs_data.opaque(),
2523       labels_data.data(), labels_lengths_data.data(), input_lengths_data.data(),
2524       costs_data.opaque(), grads_desc.handle(), grads_data.opaque(),
2525       MIOPEN_CTC_LOSS_ALGO_DETERMINISTIC, ctc_loss_desc.handle(),
2526       scratch_memory.opaque(), scratch_memory.size());
2527   if (status != miopenStatusSuccess) {
2528     LOG(FATAL) << "call to miopenCTCLoss failed: " << ToString(status);
2529     return port::InternalError("Failure during MIOpen CTC Loss");
2530   }
2531 
2532   return port::Status::OK();
2533 }
2534 
2535 port::Status MIOpenSupport::DoCtcLoss(
2536     Stream* stream, dnn::DataType element_type,
2537     const dnn::RnnStateTensorDescriptor& probs_desc,
2538     const DeviceMemoryBase probs_data, absl::Span<const int> labels_data,
2539     absl::Span<const int> labels_lengths_data,
2540     absl::Span<const int> input_lengths_data, DeviceMemoryBase costs_data,
2541     const dnn::RnnStateTensorDescriptor& grads_desc,
2542     DeviceMemoryBase grads_data, DeviceMemory<uint8> scratch_memory,
2543     int ctc_loss_algo_id) {
2544   // Current MIOpen CTC Loss only supports the float data type
2545   if (element_type != dnn::DataType::kFloat) {
2546     return port::Status(port::error::INVALID_ARGUMENT,
2547                         "MIOpenCTCLossDescriptor is supported only when the "
2548                         "DataType is float");
2549   }
2550 
2551   MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type));
2552 
2553   const MIOpenRnnStateTensorDescriptor& miopen_probs_desc =
2554       static_cast<const MIOpenRnnStateTensorDescriptor&>(probs_desc);
2555 
2556   const MIOpenRnnStateTensorDescriptor& miopen_grads_desc =
2557       static_cast<const MIOpenRnnStateTensorDescriptor&>(grads_desc);
2558 
2559   return DoCtcLossImpl(stream, miopen_probs_desc, probs_data, labels_data,
2560                        labels_lengths_data, input_lengths_data, costs_data,
2561                        miopen_grads_desc, grads_data, miopen_ctc_loss_desc,
2562                        scratch_memory, ctc_loss_algo_id);
2563 }
2564 
2565 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
2566 MIOpenSupport::createRnnDescriptor(
2567     int num_layers, int hidden_size, int input_size, int cell_size,
2568     int batch_size, dnn::RnnInputMode input_mode,
2569     dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
2570     dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
2571     float dropout, uint64 seed, ScratchAllocator* state_allocator,
2572     bool use_padded_io) {
2573   // ROCM TODO: batch_size is only used by the dynamic persistent RNN
2574   // algorithm, which is not supported by MIOpen yet.
2575   if (use_padded_io) {
2576     return port::Status(port::error::INVALID_ARGUMENT,
2577                         "ROCm MIOpen only supports packed input output.");
2578   }
2579 
2580   bool use_projection = cell_size != 0 && hidden_size < cell_size;
2581   if (use_projection) {
2582     return port::Status(
2583         port::error::INVALID_ARGUMENT,
2584         "ROCm MIOpen does not support RNN ProjectionLayers yet.");
2585   }
2586 
2587   auto miopen = miopen_->GetHandle(parent_, nullptr);
2588   std::unique_ptr<MIOpenRnnDescriptor> rnn_desc(new MIOpenRnnDescriptor(
2589       miopen.handle(), num_layers, hidden_size, input_size,
2590       ToMIOpenRnnInputMode(input_mode),
2591       ToMIOpenRnnDirectionMode(direction_mode), ToMIOpenRnnMode(rnn_mode),
2592       ToMIOpenDataType(data_type), dropout, seed, state_allocator));
2593   if (!rnn_desc->ok()) {
2594     return rnn_desc->Status();
2595   }
2596   return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
2597       std::move(rnn_desc));
2598 }
2599 
2600 port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
2601 MIOpenSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
2602                                                  int data_size,
2603                                                  dnn::DataType data_type) {
2604   std::unique_ptr<MIOpenRnnSequenceTensorDescriptor> seq_desc(
2605       new MIOpenRnnSequenceTensorDescriptor(seq_length, batch_size, data_size,
2606                                             ToMIOpenDataType(data_type)));
2607   if (!seq_desc->ok()) {
2608     return seq_desc->Status();
2609   }
2610   return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
2611       std::move(seq_desc));
2612 }
2613 
2614 port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
2615 MIOpenSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
2616                                               int data_size,
2617                                               dnn::DataType data_type) {
2618   std::unique_ptr<MIOpenRnnStateTensorDescriptor> state_desc(
2619       new MIOpenRnnStateTensorDescriptor(num_layer, batch_size, data_size,
2620                                          ToMIOpenDataType(data_type)));
2621   if (!state_desc->ok()) {
2622     return state_desc->Status();
2623   }
2624   return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
2625       std::move(state_desc));
2626 }
2627 
2628 bool MIOpenSupport::DoRnnForward(
2629     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2630     const dnn::RnnSequenceTensorDescriptor& input_desc,
2631     const DeviceMemory<Eigen::half>& input_data,
2632     const DeviceMemory<int>& seq_lengths_data,
2633     const dnn::RnnStateTensorDescriptor& input_h_desc,
2634     const DeviceMemory<Eigen::half>& input_h_data,
2635     const dnn::RnnStateTensorDescriptor& input_c_desc,
2636     const DeviceMemory<Eigen::half>& input_c_data,
2637     const DeviceMemory<Eigen::half>& params,
2638     const dnn::RnnSequenceTensorDescriptor& output_desc,
2639     DeviceMemory<Eigen::half>* output_data,
2640     const dnn::RnnStateTensorDescriptor& output_h_desc,
2641     DeviceMemory<Eigen::half>* output_h_data,
2642     const dnn::RnnStateTensorDescriptor& output_c_desc,
2643     DeviceMemory<Eigen::half>* output_c_data, bool is_training,
2644     ScratchAllocator* reserve_space_allocator,
2645     ScratchAllocator* workspace_allocator,
2646     dnn::ProfileResult* output_profile_result) {
2647   // ROCM TODO: output_profile_result is ignored for now
2648 
2649   const MIOpenRnnDescriptor& miopen_rnn_desc =
2650       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2651   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2652       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2653   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2654       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2655   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2656       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2657   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2658       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2659   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2660       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2661   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2662       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2663 
2664   return DoRnnForwardImpl<Eigen::half>(
2665       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2666       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2667       params, miopen_output_desc, output_data, miopen_output_h_desc,
2668       output_h_data, miopen_output_c_desc, output_c_data, is_training,
2669       reserve_space_allocator, workspace_allocator);
2670 }
2671 
2672 bool MIOpenSupport::DoRnnForward(
2673     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2674     const dnn::RnnSequenceTensorDescriptor& input_desc,
2675     const DeviceMemory<float>& input_data,
2676     const DeviceMemory<int>& seq_lengths_data,
2677     const dnn::RnnStateTensorDescriptor& input_h_desc,
2678     const DeviceMemory<float>& input_h_data,
2679     const dnn::RnnStateTensorDescriptor& input_c_desc,
2680     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2681     const dnn::RnnSequenceTensorDescriptor& output_desc,
2682     DeviceMemory<float>* output_data,
2683     const dnn::RnnStateTensorDescriptor& output_h_desc,
2684     DeviceMemory<float>* output_h_data,
2685     const dnn::RnnStateTensorDescriptor& output_c_desc,
2686     DeviceMemory<float>* output_c_data, bool is_training,
2687     ScratchAllocator* reserve_space_allocator,
2688     ScratchAllocator* workspace_allocator,
2689     dnn::ProfileResult* output_profile_result) {
2690   // ROCM TODO: output_profile_result is ignored for now
2691 
2692   const MIOpenRnnDescriptor& miopen_rnn_desc =
2693       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2694   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2695       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2696   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2697       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2698   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2699       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2700   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2701       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2702   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2703       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2704   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2705       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2706 
2707   return DoRnnForwardImpl<float>(
2708       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2709       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2710       params, miopen_output_desc, output_data, miopen_output_h_desc,
2711       output_h_data, miopen_output_c_desc, output_c_data, is_training,
2712       reserve_space_allocator, workspace_allocator);
2713 }
2714 
2715 bool MIOpenSupport::DoRnnForward(
2716     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2717     const dnn::RnnSequenceTensorDescriptor& input_desc,
2718     const DeviceMemory<double>& input_data,
2719     const DeviceMemory<int>& seq_lengths_data,
2720     const dnn::RnnStateTensorDescriptor& input_h_desc,
2721     const DeviceMemory<double>& input_h_data,
2722     const dnn::RnnStateTensorDescriptor& input_c_desc,
2723     const DeviceMemory<double>& input_c_data,
2724     const DeviceMemory<double>& params,
2725     const dnn::RnnSequenceTensorDescriptor& output_desc,
2726     DeviceMemory<double>* output_data,
2727     const dnn::RnnStateTensorDescriptor& output_h_desc,
2728     DeviceMemory<double>* output_h_data,
2729     const dnn::RnnStateTensorDescriptor& output_c_desc,
2730     DeviceMemory<double>* output_c_data, bool is_training,
2731     ScratchAllocator* reserve_space_allocator,
2732     ScratchAllocator* workspace_allocator,
2733     dnn::ProfileResult* output_profile_result) {
2734   LOG(ERROR) << "miopen does not support double type RNN fwd yet";
2735   return false;
2736 }
2737 
2738 bool MIOpenSupport::DoRnnBackward(
2739     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2740     const dnn::RnnSequenceTensorDescriptor& input_desc,
2741     const DeviceMemory<Eigen::half>& input_data,
2742     const DeviceMemory<int>& seq_lengths_data,
2743     const dnn::RnnStateTensorDescriptor& input_h_desc,
2744     const DeviceMemory<Eigen::half>& input_h_data,
2745     const dnn::RnnStateTensorDescriptor& input_c_desc,
2746     const DeviceMemory<Eigen::half>& input_c_data,
2747     const DeviceMemory<Eigen::half>& params,
2748     const dnn::RnnSequenceTensorDescriptor& output_desc,
2749     const DeviceMemory<Eigen::half>& output_data,
2750     const dnn::RnnStateTensorDescriptor& output_h_desc,
2751     const DeviceMemory<Eigen::half>& output_h_data,
2752     const dnn::RnnStateTensorDescriptor& output_c_desc,
2753     const DeviceMemory<Eigen::half>& output_c_data,
2754     const DeviceMemory<Eigen::half>& output_backprop_data,
2755     const DeviceMemory<Eigen::half>& output_h_backprop_data,
2756     const DeviceMemory<Eigen::half>& output_c_backprop_data,
2757     DeviceMemory<Eigen::half>* input_backprop_data,
2758     DeviceMemory<Eigen::half>* input_h_backprop_data,
2759     DeviceMemory<Eigen::half>* input_c_backprop_data,
2760     DeviceMemory<Eigen::half>* params_backprop_data,
2761     DeviceMemory<uint8>* reserve_space_data,
2762     ScratchAllocator* workspace_allocator,
2763     dnn::ProfileResult* output_profile_result) {
2764   // ROCM TODO: output_profile_result is ignored for now
2765 
2766   const MIOpenRnnDescriptor& miopen_rnn_desc =
2767       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2768   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2769       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2770   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2771       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2772   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2773       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2774   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2775       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2776   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2777       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2778   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2779       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2780 
2781   return DoRnnBackwardImpl<Eigen::half>(
2782       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2783       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2784       params, miopen_output_desc, output_data, miopen_output_h_desc,
2785       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
2786       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
2787       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
2788       reserve_space_data, workspace_allocator);
2789 }
2790 
2791 bool MIOpenSupport::DoRnnBackward(
2792     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2793     const dnn::RnnSequenceTensorDescriptor& input_desc,
2794     const DeviceMemory<float>& input_data,
2795     const DeviceMemory<int>& seq_lengths_data,
2796     const dnn::RnnStateTensorDescriptor& input_h_desc,
2797     const DeviceMemory<float>& input_h_data,
2798     const dnn::RnnStateTensorDescriptor& input_c_desc,
2799     const DeviceMemory<float>& input_c_data, const DeviceMemory<float>& params,
2800     const dnn::RnnSequenceTensorDescriptor& output_desc,
2801     const DeviceMemory<float>& output_data,
2802     const dnn::RnnStateTensorDescriptor& output_h_desc,
2803     const DeviceMemory<float>& output_h_data,
2804     const dnn::RnnStateTensorDescriptor& output_c_desc,
2805     const DeviceMemory<float>& output_c_data,
2806     const DeviceMemory<float>& output_backprop_data,
2807     const DeviceMemory<float>& output_h_backprop_data,
2808     const DeviceMemory<float>& output_c_backprop_data,
2809     DeviceMemory<float>* input_backprop_data,
2810     DeviceMemory<float>* input_h_backprop_data,
2811     DeviceMemory<float>* input_c_backprop_data,
2812     DeviceMemory<float>* params_backprop_data,
2813     DeviceMemory<uint8>* reserve_space_data,
2814     ScratchAllocator* workspace_allocator,
2815     dnn::ProfileResult* output_profile_result) {
2816   // ROCM TODO: output_profile_result is ignored for now
2817 
2818   const MIOpenRnnDescriptor& miopen_rnn_desc =
2819       static_cast<const MIOpenRnnDescriptor&>(rnn_desc);
2820   const MIOpenRnnSequenceTensorDescriptor& miopen_input_desc =
2821       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(input_desc);
2822   const MIOpenRnnStateTensorDescriptor& miopen_input_h_desc =
2823       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_h_desc);
2824   const MIOpenRnnStateTensorDescriptor& miopen_input_c_desc =
2825       static_cast<const MIOpenRnnStateTensorDescriptor&>(input_c_desc);
2826   const MIOpenRnnSequenceTensorDescriptor& miopen_output_desc =
2827       static_cast<const MIOpenRnnSequenceTensorDescriptor&>(output_desc);
2828   const MIOpenRnnStateTensorDescriptor& miopen_output_h_desc =
2829       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_h_desc);
2830   const MIOpenRnnStateTensorDescriptor& miopen_output_c_desc =
2831       static_cast<const MIOpenRnnStateTensorDescriptor&>(output_c_desc);
2832 
2833   return DoRnnBackwardImpl<float>(
2834       stream, miopen_rnn_desc, miopen_input_desc, input_data,
2835       miopen_input_h_desc, input_h_data, miopen_input_c_desc, input_c_data,
2836       params, miopen_output_desc, output_data, miopen_output_h_desc,
2837       output_h_data, miopen_output_c_desc, output_c_data, output_backprop_data,
2838       output_h_backprop_data, output_c_backprop_data, input_backprop_data,
2839       input_h_backprop_data, input_c_backprop_data, params_backprop_data,
2840       reserve_space_data, workspace_allocator);
2841 }
2842 
2843 bool MIOpenSupport::DoRnnBackward(
2844     Stream* stream, const dnn::RnnDescriptor& rnn_desc,
2845     const dnn::RnnSequenceTensorDescriptor& input_desc,
2846     const DeviceMemory<double>& input_data,
2847     const DeviceMemory<int>& seq_lengths_data,
2848     const dnn::RnnStateTensorDescriptor& input_h_desc,
2849     const DeviceMemory<double>& input_h_data,
2850     const dnn::RnnStateTensorDescriptor& input_c_desc,
2851     const DeviceMemory<double>& input_c_data,
2852     const DeviceMemory<double>& params,
2853     const dnn::RnnSequenceTensorDescriptor& output_desc,
2854     const DeviceMemory<double>& output_data,
2855     const dnn::RnnStateTensorDescriptor& output_h_desc,
2856     const DeviceMemory<double>& output_h_data,
2857     const dnn::RnnStateTensorDescriptor& output_c_desc,
2858     const DeviceMemory<double>& output_c_data,
2859     const DeviceMemory<double>& output_backprop_data,
2860     const DeviceMemory<double>& output_h_backprop_data,
2861     const DeviceMemory<double>& output_c_backprop_data,
2862     DeviceMemory<double>* input_backprop_data,
2863     DeviceMemory<double>* input_h_backprop_data,
2864     DeviceMemory<double>* input_c_backprop_data,
2865     DeviceMemory<double>* params_backprop_data,
2866     DeviceMemory<uint8>* reserve_space_data,
2867     ScratchAllocator* workspace_allocator,
2868     dnn::ProfileResult* output_profile_result) {
2869   LOG(ERROR) << "miopen does not support double type RNN bwd yet";
2870   return false;
2871 }
2872 
2873 // This is the context required to use the TF scratch allocator:
2874 struct MIOpenAllocatorContext {
2875   MIOpenAllocatorContext(ScratchAllocator* scratch_allocator, Stream* stream)
2876       : scratch_allocator_(scratch_allocator), stream_(stream) {}
2877 
2878   ScratchAllocator* scratch_allocator_;
2879   Stream* stream_;
2880 };
2881 
2882 void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) {
2883   auto* mac = static_cast<MIOpenAllocatorContext*>(ctx);
2884   auto allocated = mac->scratch_allocator_->AllocateBytes(size_in_bytes);
2885 
2886   DeviceMemory<uint8> scratch;
2887   if (allocated.ok()) {
2888     scratch = allocated.ValueOrDie();
2889     return scratch.opaque();
2890   } else {
2891     return nullptr;
2892   }
2893 }
2894 
2895 void MIOpenDeallocatorCallback(void* ctx, void* mem) {
2896   // No deallocation is needed here: the memory comes from the TensorFlow
2897   // scratch allocator, which reclaims it automatically.
2898 }
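// These two callbacks let MIOpen allocate its internal scratch memory through
// TensorFlow's ScratchAllocator. A rough sketch of how they are wired up
// (registration via wrap::miopenSetAllocator is assumed here; the exact call
// site lives with the MIOpen Find-mode code):
//
//   MIOpenAllocatorContext ctx(scratch_allocator, stream);
//   wrap::miopenSetAllocator(miopen.handle(), MIOpenAllocatorCallback,
//                            MIOpenDeallocatorCallback, &ctx);
//
// The deallocator is a no-op because the ScratchAllocator owns the memory and
// reclaims it once the allocator goes out of scope.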
2899 
2900 port::Status MIOpenSupport::DoPrepareForConvolution(
2901     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
2902     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2903     const dnn::FilterDescriptor& filter_descriptor,
2904     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2905     DeviceMemoryBase output_data,
2906     const dnn::ConvolutionDescriptor& convolution_descriptor,
2907     const dnn::AlgorithmConfig& algorithm_config,
2908     ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc,
2909     DeviceMemory<uint8>* scratch_memory) {
2910   absl::optional<dnn::AlgorithmDesc> input_algo_desc =
2911       algorithm_config.algorithm();
2912 
2913   assert(input_algo_desc.has_value());
2914 
2915   // An algorithm has been specified.
2916   *algorithm_desc = *input_algo_desc;
2917 
2918   assert(algorithm_config.scratch_size().has_value());
2919 
2920   size_t scratch_memory_size = *(algorithm_config.scratch_size());
2921 
2922   // allocate scratch memory
2923   if (scratch_memory_size != 0) {
2924     if (scratch_allocator == nullptr) {
2925       return port::InternalError(
2926           absl::StrCat("An allocator must be specified when scratch memory is "
2927                        "needed"));
2928     }
2929     auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
2930     if (allocated.ok()) {
2931       *scratch_memory = allocated.ValueOrDie();
2932     } else {
2933       LOG(ERROR)
2934           << "Failed to allocate scratch memory - "
2935           << allocated.status().error_message() << "\n"
2936           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
2937              "larger number (e.g. 8192) to increase the max memory limit.\n"
2938           << "\tIncreasing the max memory limit might help resolve this "
2939              "error";
2940       return port::InternalError(absl::StrCat(
2941           "Failed to allocate scratch memory of size: ", scratch_memory_size));
2942     }
2943   }
2944 
2945   return port::Status::OK();
2946 }
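// DoPrepareForConvolution above assumes the caller has already selected an
// algorithm and knows its scratch requirement (both are asserted on), so all
// it does is allocate a scratch buffer of the requested size.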
2947 
2948 port::Status MIOpenSupport::DoConvolve(
2949     dnn::ConvolutionKind kind, dnn::DataType element_type,
2950     dnn::DataType output_type, Stream* stream,
2951     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
2952     const dnn::FilterDescriptor& filter_descriptor,
2953     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
2954     DeviceMemoryBase output_data,
2955     const dnn::ConvolutionDescriptor& convolution_descriptor,
2956     dnn::AlgorithmDesc algorithm_desc, DeviceMemory<uint8> scratch_memory,
2957     dnn::ProfileResult* output_profile_result) {
2958   auto miopen = miopen_->GetHandle(parent_, stream);
2959   ScopedTensorDescriptor input_nd{input_descriptor,
2960                                   ToMIOpenDataType(element_type)};
2961   ScopedTensorDescriptor output_nd{output_descriptor,
2962                                    ToMIOpenDataType(element_type)};
2963   ScopedFilterDescriptor filter{filter_descriptor,
2964                                 ToMIOpenDataType(element_type)};
2965   ScopedConvolutionDescriptor conv{convolution_descriptor,
2966                                    ToMIOpenDataType(element_type)};
2967 
2968   // Alpha is the scaling factor for input.
2969   float alpha = 1.0;
2970   // Beta is the scaling factor for output.
2971   float beta = 0.0;
2972 
2973   const bool is_profiling = output_profile_result != nullptr;
2974 
2975   std::unique_ptr<GpuTimer> timer;
2976   if (is_profiling) {
2977     timer.reset(new GpuTimer(parent_));
2978     if (!timer->Init()) {
2979       return port::Status(port::error::INTERNAL, "Failed to init timer");
2980     }
2981     // The start and stop of the timer should be as close to the MIOpen call
2982     // as possible. Other threads may still issue work onto this stream, so
2983     // multiple profiling measurements may be needed to obtain a clean timing.
2984     if (!timer->Start(AsGpuStream(stream))) {
2985       timer->Destroy();
2986       return port::Status(port::error::INTERNAL, "Failed to start timer");
2987     }
2988   }
2989 
2990   miopenStatus_t status = miopenStatusSuccess;
2991   switch (kind) {
2992     case dnn::ConvolutionKind::FORWARD: {
2993       if (use_immediate_mode_) {
2994         status = wrap::miopenConvolutionForwardImmediate(
2995             miopen.handle(), filter.handle(), filter_data.opaque(),
2996             input_nd.handle(), input_data.opaque(), conv.handle(),
2997             output_nd.handle(), output_data.opaque(), scratch_memory.opaque(),
2998             scratch_memory.size(),
2999             static_cast<uint64_t>(algorithm_desc.algo_id()));
3000       } else {
3001         status = wrap::miopenConvolutionForward(
3002             miopen.handle(), &alpha, input_nd.handle(), input_data.opaque(),
3003             filter.handle(), filter_data.opaque(), conv.handle(),
3004             static_cast<miopenConvFwdAlgorithm_t>(algorithm_desc.algo_id()),
3005             &beta, output_nd.handle(), output_data.opaque(),
3006             scratch_memory.opaque(), scratch_memory.size());
3007       }
3008 
3009       break;
3010     }
3011     case dnn::ConvolutionKind::BACKWARD_DATA: {
3012       if (use_immediate_mode_) {
3013         status = wrap::miopenConvolutionBackwardDataImmediate(
3014             miopen.handle(), output_nd.handle(), output_data.opaque(),
3015             filter.handle(), filter_data.opaque(), conv.handle(),
3016             input_nd.handle(), input_data.opaque(), scratch_memory.opaque(),
3017             scratch_memory.size(),
3018             static_cast<uint64_t>(algorithm_desc.algo_id()));
3019       } else {
3020         status = wrap::miopenConvolutionBackwardData(
3021             miopen.handle(), &alpha, output_nd.handle(), output_data.opaque(),
3022             filter.handle(), filter_data.opaque(), conv.handle(),
3023             static_cast<miopenConvBwdDataAlgorithm_t>(algorithm_desc.algo_id()),
3024             &beta, input_nd.handle(), input_data.opaque(),
3025             scratch_memory.opaque(), scratch_memory.size());
3026       }
3027       break;
3028     }
3029     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3030       if (use_immediate_mode_) {
3031         status = wrap::miopenConvolutionBackwardWeightsImmediate(
3032             miopen.handle(), output_nd.handle(), output_data.opaque(),
3033             input_nd.handle(), input_data.opaque(), conv.handle(),
3034             filter.handle(), filter_data.opaque(), scratch_memory.opaque(),
3035             scratch_memory.size(),
3036             static_cast<uint64_t>(algorithm_desc.algo_id()));
3037       } else {
3038         status = wrap::miopenConvolutionBackwardWeights(
3039             miopen.handle(), &alpha, output_nd.handle(), output_data.opaque(),
3040             input_nd.handle(), input_data.opaque(), conv.handle(),
3041             static_cast<miopenConvBwdWeightsAlgorithm_t>(
3042                 algorithm_desc.algo_id()),
3043             &beta, filter.handle(), filter_data.opaque(),
3044             scratch_memory.opaque(), scratch_memory.size());
3045       }
3046       break;
3047     }
3048     default:
3049       return port::InternalError(
3050           absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
3051   }
3052 
3053   if (is_profiling) {
3054     if (!timer->Stop(AsGpuStream(stream))) {
3055       timer->Destroy();
3056       return port::Status(port::error::INTERNAL, "Failed to stop timer");
3057     }
3058     if (status == miopenStatusSuccess) {
3059       dnn::AlgorithmDesc algotype(algorithm_desc.algo_id(), false);
3060       output_profile_result->set_algorithm(algotype);
3061       output_profile_result->set_elapsed_time_in_ms(
3062           timer->GetElapsedMilliseconds());
3063       output_profile_result->set_scratch_size(scratch_memory.size());
3064     }
3065     timer->Destroy();
3066   }
3067 
3068   if (status != miopenStatusSuccess) {
3069     return port::InternalError(absl::StrCat(
3070         "Failed to enqueue convolution on stream: ", ToString(status)));
3071   }
3072 
3073   return port::Status::OK();
3074 }
3075 
3076 bool MIOpenSupport::GetConvolveAlgorithms(
3077     // ROCM TODO: refactor cc_major / cc_minor
3078     CudaComputeCapability cuda_compute_capability,
3079     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3080   out_algorithms->assign({
3081       // clang-format off
3082       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoGEMM, false),
3083       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoDirect, false),
3084       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoFFT, false),
3085       dnn::AlgorithmDesc(miopenConvolutionFwdAlgoWinograd, false),
3086       // clang-format on
3087   });
3088   return true;
3089 }
3090 
3091 bool MIOpenSupport::GetMIOpenConvolveAlgorithms(
3092     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3093     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3094     const dnn::FilterDescriptor& filter_descriptor,
3095     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3096     DeviceMemoryBase output_data,
3097     const dnn::ConvolutionDescriptor& convolution_descriptor,
3098     ScratchAllocator* scratch_allocator,
3099     std::vector<dnn::ProfileResult>* out_algorithms) {
3100   return use_immediate_mode_
3101              ? GetMIOpenConvolveAlgorithmsImmediateMode(
3102                    kind, element_type, stream, input_descriptor, input_data,
3103                    filter_descriptor, filter_data, output_descriptor,
3104                    output_data, convolution_descriptor, scratch_allocator,
3105                    out_algorithms)
3106              : GetMIOpenConvolveAlgorithmsFindMode(
3107                    kind, element_type, stream, input_descriptor, input_data,
3108                    filter_descriptor, filter_data, output_descriptor,
3109                    output_data, convolution_descriptor, scratch_allocator,
3110                    out_algorithms);
3111 }
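// Two ways of enumerating convolution algorithms are supported:
//  * immediate mode (below) asks MIOpen for its pre-computed solution list via
//    the *GetSolutionCount / *GetSolution / *CompileSolution APIs, and
//  * find mode runs MIOpen's auto-tuning Find path, which first needs the
//    workspace size queried via the *GetWorkSpaceSize APIs.
// The choice is driven by the use_immediate_mode_ flag above.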
3112 
3113 bool MIOpenSupport::GetMIOpenConvolveAlgorithmsImmediateMode(
3114     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3115     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3116     const dnn::FilterDescriptor& filter_descriptor,
3117     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3118     DeviceMemoryBase output_data,
3119     const dnn::ConvolutionDescriptor& convolution_descriptor,
3120     ScratchAllocator* scratch_allocator,
3121     std::vector<dnn::ProfileResult>* out_algorithms) {
3122   auto miopen = miopen_->GetHandle(parent_, stream);
3123 
3124   ScopedTensorDescriptor input_nd{input_descriptor,
3125                                   ToMIOpenDataType(element_type)};
3126   ScopedTensorDescriptor output_nd{output_descriptor,
3127                                    ToMIOpenDataType(element_type)};
3128   ScopedFilterDescriptor filter{filter_descriptor,
3129                                 ToMIOpenDataType(element_type)};
3130   ScopedConvolutionDescriptor conv{convolution_descriptor,
3131                                    ToMIOpenDataType(element_type)};
3132 
3133   // First determine the number of algorithms available
3134   size_t maxSolutionCount = 0;
3135 
3136   switch (kind) {
3137     case dnn::ConvolutionKind::FORWARD: {
3138       auto status = wrap::miopenConvolutionForwardGetSolutionCount(
3139           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3140           output_nd.handle(), &maxSolutionCount);
3141       if (status != miopenStatusSuccess) {
3142         LOG(FATAL)
3143             << "call to miopenConvolutionForwardGetSolutionCount failed: "
3144             << ToString(status);
3145         return false;
3146       }
3147       break;
3148     }
3149     case dnn::ConvolutionKind::BACKWARD_DATA: {
3150       auto status = wrap::miopenConvolutionBackwardDataGetSolutionCount(
3151           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3152           input_nd.handle(), &maxSolutionCount);
3153       if (status != miopenStatusSuccess) {
3154         LOG(FATAL) << "call to miopenConvolutionBackwardDataGetSolutionCount "
3155                       "failed: "
3156                    << ToString(status);
3157         return false;
3158       }
3159       break;
3160     }
3161     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3162       auto status = wrap::miopenConvolutionBackwardWeightsGetSolutionCount(
3163           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3164           filter.handle(), &maxSolutionCount);
3165       if (status != miopenStatusSuccess) {
3166         LOG(FATAL)
3167             << "call to miopenConvolutionBackwardWeightsGetSolutionCount "
3168                "failed: "
3169             << ToString(status);
3170         return false;
3171       }
3172       break;
3173     }
3174     default: {
3175       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3176       return false;
3177       break;
3178     }
3179   }
3180 
3181   VLOG(kConvDebugVlogLevel)
3182       << "Number of conv solutions max: " << maxSolutionCount;
3183 
3184   if (return_best_algo_only_) {
3185     VLOG(kConvDebugVlogLevel) << "TF_ROCM_RETURN_BEST_ALGO_ONLY is set, "
3186                               << "setting maxSolutionCount to 1";
3187     maxSolutionCount = 1;
3188   }
3189 
3190   size_t solutionCount = 0;
3191   std::unique_ptr<miopenConvSolution_t[]> solutions(
3192       new miopenConvSolution_t[maxSolutionCount]);
3193 
3194   switch (kind) {
3195     case dnn::ConvolutionKind::FORWARD: {
3196       auto status = wrap::miopenConvolutionForwardGetSolution(
3197           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3198           output_nd.handle(), maxSolutionCount, &solutionCount,
3199           solutions.get());
3200 
3201       if (status != miopenStatusSuccess) {
3202         LOG(FATAL) << "call to miopenConvolutionForwardGetSolution failed: "
3203                    << ToString(status);
3204         return false;
3205       }
3206 
3207       VLOG(kConvDebugVlogLevel)
3208           << "Number of conv solutions actual: " << solutionCount;
3209 
3210       for (size_t i = 0; i < solutionCount; i++) {
3211         miopenConvSolution_t solution = solutions[i];
3212 
3213         VLOG(kConvDebugVlogLevel)
3214             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3215             << ", " << solution.workspace_size << ", " << solution.solution_id
3216             << ", " << ToString(solution.algorithm);
3217 
3218         status = wrap::miopenConvolutionForwardCompileSolution(
3219             miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3220             output_nd.handle(), solution.solution_id);
3221 
3222         if (status != miopenStatusSuccess) {
3223           LOG(FATAL)
3224               << "call to miopenConvolutionForwardCompileSolution failed: "
3225               << ToString(status);
3226           return false;
3227         }
3228 
3229         out_algorithms->emplace_back(
3230             GetProfileResultFromConvSolution(solution));
3231       }
3232       break;
3233     }
3234 
3235     case dnn::ConvolutionKind::BACKWARD_DATA: {
3236       auto status = wrap::miopenConvolutionBackwardDataGetSolution(
3237           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3238           input_nd.handle(), maxSolutionCount, &solutionCount, solutions.get());
3239       if (status != miopenStatusSuccess) {
3240         LOG(FATAL)
3241             << "call to miopenConvolutionBackwardDataGetSolution failed: "
3242             << ToString(status);
3243         return false;
3244       }
3245 
3246       VLOG(kConvDebugVlogLevel)
3247           << "Number of conv solutions actual: " << solutionCount;
3248 
3249       for (size_t i = 0; i < solutionCount; i++) {
3250         miopenConvSolution_t solution = solutions[i];
3251 
3252         VLOG(kConvDebugVlogLevel)
3253             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3254             << ", " << solution.workspace_size << ", " << solution.solution_id
3255             << ", " << ToString(solution.algorithm);
3256 
3257         status = wrap::miopenConvolutionBackwardDataCompileSolution(
3258             miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3259             input_nd.handle(), solution.solution_id);
3260 
3261         if (status != miopenStatusSuccess) {
3262           LOG(FATAL) << "call to miopenConvolutionBackwardDataCompileSolution "
3263                         "failed: "
3264                      << ToString(status);
3265           return false;
3266         }
3267 
3268         out_algorithms->emplace_back(
3269             GetProfileResultFromConvSolution(solution));
3270       }
3271       break;
3272     }
3273     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3274       auto status = wrap::miopenConvolutionBackwardWeightsGetSolution(
3275           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3276           filter.handle(), maxSolutionCount, &solutionCount, solutions.get());
3277       if (status != miopenStatusSuccess) {
3278         LOG(FATAL)
3279             << "call to miopenConvolutionBackwardWeightsGetSolution failed: "
3280             << ToString(status);
3281         return false;
3282       }
3283 
3284       VLOG(kConvDebugVlogLevel)
3285           << "Number of conv solutions actual: " << solutionCount;
3286 
3287       for (size_t i = 0; i < solutionCount; i++) {
3288         miopenConvSolution_t solution = solutions[i];
3289 
3290         VLOG(kConvDebugVlogLevel)
3291             << "solution " << i << " (time, mem, id, algo) =  " << solution.time
3292             << ", " << solution.workspace_size << ", " << solution.solution_id
3293             << ", " << ToString(solution.algorithm);
3294 
3295         status = wrap::miopenConvolutionBackwardWeightsCompileSolution(
3296             miopen.handle(), output_nd.handle(), input_nd.handle(),
3297             conv.handle(), filter.handle(), solution.solution_id);
3298 
3299         if (status != miopenStatusSuccess) {
3300           LOG(FATAL)
3301               << "call to miopenConvolutionBackwardWeightsCompileSolution "
3302                  "failed: "
3303               << ToString(status);
3304           return false;
3305         }
3306 
3307         out_algorithms->emplace_back(
3308             GetProfileResultFromConvSolution(solution));
3309       }
3310       break;
3311     }
3312     default: {
3313       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3314       return false;
3315       break;
3316     }
3317   }
3318 
3319   return true;
3320 }
3321 
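// Find-mode algorithm selection: query the workspace size for the requested
// convolution kind, allocate that scratch through the provided
// ScratchAllocator, then invoke the corresponding
// miopenFindConvolution*Algorithm call (with exhaustive search disabled) to
// time the candidates and return only the single best one as a
// dnn::ProfileResult.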
3322 bool MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode(
3323     dnn::ConvolutionKind kind, dnn::DataType element_type, Stream* stream,
3324     const dnn::BatchDescriptor& input_descriptor, DeviceMemoryBase input_data,
3325     const dnn::FilterDescriptor& filter_descriptor,
3326     DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor,
3327     DeviceMemoryBase output_data,
3328     const dnn::ConvolutionDescriptor& convolution_descriptor,
3329     ScratchAllocator* scratch_allocator,
3330     std::vector<dnn::ProfileResult>* out_algorithms) {
3331   auto miopen = miopen_->GetHandle(parent_, stream);
3332 
3333   ScopedTensorDescriptor input_nd{input_descriptor,
3334                                   ToMIOpenDataType(element_type)};
3335   ScopedTensorDescriptor output_nd{output_descriptor,
3336                                    ToMIOpenDataType(element_type)};
3337   ScopedFilterDescriptor filter{filter_descriptor,
3338                                 ToMIOpenDataType(element_type)};
3339   ScopedConvolutionDescriptor conv{convolution_descriptor,
3340                                    ToMIOpenDataType(element_type)};
3341 
3342   // Determine the workspace memory size that will be needed by the call to Find
3343   size_t scratch_memory_size = 0;
3344   switch (kind) {
3345     case dnn::ConvolutionKind::FORWARD: {
3346       auto status = wrap::miopenConvolutionForwardGetWorkSpaceSize(
3347           miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
3348           output_nd.handle(), &scratch_memory_size);
3349       if (status != miopenStatusSuccess) {
3350         LOG(FATAL)
3351             << "call to miopenConvolutionForwardGetWorkSpaceSize failed: "
3352             << ToString(status);
3353         return false;
3354       }
3355       break;
3356     }
3357     case dnn::ConvolutionKind::BACKWARD_DATA: {
3358       auto status = wrap::miopenConvolutionBackwardDataGetWorkSpaceSize(
3359           miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
3360           input_nd.handle(), &scratch_memory_size);
3361       if (status != miopenStatusSuccess) {
3362         LOG(FATAL)
3363             << "call to miopenConvolutionBackwardDataGetWorkSpaceSize failed: "
3364             << ToString(status);
3365         return false;
3366       }
3367       break;
3368     }
3369     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3370       auto status = wrap::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
3371           miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
3372           filter.handle(), &scratch_memory_size);
3373       if (status != miopenStatusSuccess) {
3374         LOG(FATAL)
3375             << "call to miopenConvolutionBackwardWeightsGetWorkSpaceSize "
3376                "failed: "
3377             << ToString(status);
3378         return false;
3379       }
3380       break;
3381     }
3382     default: {
3383       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3384       return false;
3385       break;
3386     }
3387   }
3388 
3389   // allocate scratch memory
3390   DeviceMemory<uint8> scratch_memory;
3391   if (scratch_memory_size != 0) {
3392     if (scratch_allocator == nullptr) {
3393       LOG(FATAL)
3394           << "An allocator must be specified when scratch memory is needed";
3395       return false;
3396     }
3397     auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
3398     if (allocated.ok()) {
3399       scratch_memory = allocated.ValueOrDie();
3400     } else {
3401       LOG(FATAL)
3402           << "Failed to allocate scratch memory - "
3403           << allocated.status().error_message() << "\n"
3404           << "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
3405              "larger number (e.g. 8192) to increase the max memory limit.\n"
3406           << "\tIncreasing the max memory limit might help resolve this "
3407              "error";
3408       return false;
3409     }
3410   }
3411 
3412   // Only get the best algorithm for Find Mode
3413   size_t requestedAlgorithmCount = 1;
3414 
3415   VLOG(kConvDebugVlogLevel)
3416       << "Number of conv algorithms to request: " << requestedAlgorithmCount;
3417 
3418   miopenConvAlgoPerf_t returnedAlgorithm;
3419 
3420   int returnedAlgorithmCount = 0;
3421   bool exhaustiveSearch = false;
3422 
3423   switch (kind) {
3424     case dnn::ConvolutionKind::FORWARD: {
3425       auto status = wrap::miopenFindConvolutionForwardAlgorithm(
3426           miopen.handle(), input_nd.handle(), input_data.opaque(),
3427           filter.handle(), filter_data.opaque(), conv.handle(),
3428           output_nd.handle(), output_data.opaque(), requestedAlgorithmCount,
3429           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3430           scratch_memory_size, exhaustiveSearch);
3431       if (status != miopenStatusSuccess) {
3432         LOG(FATAL) << "call to miopenFindConvolutionForwardAlgorithm failed: "
3433                    << ToString(status);
3434         return false;
3435       }
3436       break;
3437     }
3438     case dnn::ConvolutionKind::BACKWARD_DATA: {
3439       auto status = wrap::miopenFindConvolutionBackwardDataAlgorithm(
3440           miopen.handle(), output_nd.handle(), output_data.opaque(),
3441           filter.handle(), filter_data.opaque(), conv.handle(),
3442           input_nd.handle(), input_data.opaque(), requestedAlgorithmCount,
3443           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3444           scratch_memory_size, exhaustiveSearch);
3445       if (status != miopenStatusSuccess) {
3446         LOG(FATAL)
3447             << "call to miopenFindConvolutionBackwardDataAlgorithm failed: "
3448             << ToString(status);
3449         return false;
3450       }
3451       break;
3452     }
3453     case dnn::ConvolutionKind::BACKWARD_FILTER: {
3454       auto status = wrap::miopenFindConvolutionBackwardWeightsAlgorithm(
3455           miopen.handle(), output_nd.handle(), output_data.opaque(),
3456           input_nd.handle(), input_data.opaque(), conv.handle(),
3457           filter.handle(), filter_data.opaque(), requestedAlgorithmCount,
3458           &returnedAlgorithmCount, &returnedAlgorithm, scratch_memory.opaque(),
3459           scratch_memory_size, exhaustiveSearch);
3460       if (status != miopenStatusSuccess) {
3461         LOG(FATAL) << "call to miopenFindConvolutionBackwardWeightsAlgorithm "
3462                       "failed: "
3463                    << ToString(status);
3464         return false;
3465       }
3466       break;
3467     }
3468     default: {
3469       LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
3470       return false;
3471       break;
3472     }
3473   }
3474 
3475   out_algorithms->emplace_back(
3476       GetProfileResultFromConvAlgoPerf(kind, returnedAlgorithm));
3477 
3478   return true;
3479 }
3480 
3481 bool MIOpenSupport::GetRnnAlgorithms(
3482     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3483   // ROCM TODO: implement this with proper MIOpen API
3484   return true;
3485 }
3486 
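// The two getters below statically advertise the MIOpen backward-data and
// backward-weights algorithm enums; the actual per-shape algorithm choice is
// still made later via the immediate-mode / find-mode paths above.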
3487 bool MIOpenSupport::GetConvolveBackwardDataAlgorithms(
3488     // ROCM TODO: refactor cc_major / cc_minor
3489     CudaComputeCapability cuda_compute_capability,
3490     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3491   out_algorithms->assign({
3492       // clang-format off
3493       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoGEMM, false),
3494       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoDirect, false),
3495       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoFFT, false),
3496       dnn::AlgorithmDesc(miopenConvolutionBwdDataAlgoWinograd, false),
3497       // clang-format on
3498   });
3499   return true;
3500 }
3501 
3502 bool MIOpenSupport::GetConvolveBackwardFilterAlgorithms(
3503     // ROCM TODO: refactor cc_major / cc_minor
3504     CudaComputeCapability cuda_compute_capability,
3505     std::vector<dnn::AlgorithmDesc>* out_algorithms) {
3506   out_algorithms->assign({
3507       // clang-format off
3508       dnn::AlgorithmDesc(miopenConvolutionBwdWeightsAlgoGEMM, false),
3509       dnn::AlgorithmDesc(miopenConvolutionBwdWeightsAlgoDirect, false),
3510       // clang-format on
3511   });
3512   return true;
3513 }
3514 
3515 bool MIOpenSupport::DoBatchNormalizationForward(
3516     Stream* stream, const DeviceMemory<Eigen::half>& x,
3517     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3518     const DeviceMemory<float>& estimated_mean,
3519     const DeviceMemory<float>& estimated_variance,
3520     const DeviceMemory<Eigen::half>& side_input,
3521     const dnn::BatchDescriptor& x_desc,
3522     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3523     const double exponential_average_factor,
3524     dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y,
3525     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3526     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3527     bool is_training, ScratchAllocator* reserve_space_allocator,
3528     ScratchAllocator* workspace_allocator) {
3529   return DoBatchNormalizationForwardImpl<Eigen::half, float>(
3530       stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
3531       estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
3532       epsilon, exponential_average_factor, activation_mode, y, batch_mean,
3533       batch_var, saved_mean, saved_inv_var, is_training);
3534 }
3535 
3536 bool MIOpenSupport::DoBatchNormalizationForward(
3537     Stream* stream, const DeviceMemory<float>& x,
3538     const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
3539     const DeviceMemory<float>& estimated_mean,
3540     const DeviceMemory<float>& estimated_variance,
3541     const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
3542     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3543     const double exponential_average_factor,
3544     dnn::ActivationMode activation_mode, DeviceMemory<float>* y,
3545     DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
3546     DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
3547     bool is_training, ScratchAllocator* reserve_space_allocator,
3548     ScratchAllocator* workspace_allocator) {
3549   return DoBatchNormalizationForwardImpl<float, float>(
3550       stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset,
3551       estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
3552       epsilon, exponential_average_factor, activation_mode, y, batch_mean,
3553       batch_var, saved_mean, saved_inv_var, is_training);
3554 }
3555 
3556 template <class T, class U>
3557 bool MIOpenSupport::DoBatchNormalizationForwardImpl(
3558     Stream* stream, dnn::DataType input_data_type,
3559     dnn::DataType scale_data_type, const DeviceMemory<T>& x,
3560     const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
3561     const DeviceMemory<U>& estimated_mean,
3562     const DeviceMemory<U>& estimated_variance,
3563     const DeviceMemory<T>& side_input, const dnn::BatchDescriptor& x_desc,
3564     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3565     const double exponential_average_factor,
3566     dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
3567     DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
3568     DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
3569     bool is_training) {
3570   auto miopen = miopen_->GetHandle(parent_, stream);
3571 
3572   ScopedTensorDescriptor x_descriptor{x_desc,
3573                                       ToMIOpenDataType(input_data_type)};
3574   ScopedTensorDescriptor scale_offset_descriptor{
3575       scale_offset_desc, ToMIOpenDataType(scale_data_type)};
3576   miopenBatchNormMode_t mode = miopenBNSpatial;
3577   float one = 1.0;
3578   float zero = 0.0;
3579 
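  // Training mode computes the per-batch mean/variance (updating the running
  // statistics with exponential_average_factor) and stashes the saved mean /
  // inverse variance for the backward pass; inference mode simply normalizes
  // with the caller-provided estimated_mean / estimated_variance.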
3580   auto status = miopenStatusInvalidValue;
3581   if (is_training) {
3582     status = wrap::miopenBatchNormalizationForwardTraining(
3583         miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3584         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3585         const_cast<void*>(scale.opaque()), const_cast<void*>(offset.opaque()),
3586         exponential_average_factor, batch_mean->opaque(), batch_var->opaque(),
3587         epsilon, saved_mean->opaque(), saved_inv_var->opaque());
3588   } else {
3589     const void* maybe_inv_var = estimated_variance.opaque();
3590     status = wrap::miopenBatchNormalizationForwardInference(
3591         miopen.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
3592         x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
3593         const_cast<void*>(scale.opaque()), const_cast<void*>(offset.opaque()),
3594         const_cast<void*>(estimated_mean.opaque()),
3595         const_cast<void*>(maybe_inv_var), epsilon);
3596   }
3597   if (status != miopenStatusSuccess) {
3598     LOG(ERROR) << "failed to enqueue forward batch normalization on stream: "
3599                << ToString(status);
3600     return false;
3601   }
3602   return true;
3603 }
3604 
3605 bool MIOpenSupport::DoBatchNormalizationBackward(
3606     Stream* stream, const DeviceMemory<Eigen::half>& y_backprop,
3607     const DeviceMemory<Eigen::half>& x, const DeviceMemory<float>& scale,
3608     const DeviceMemory<float>& mean, const DeviceMemory<float>& inv_var,
3609     const dnn::BatchDescriptor& x_desc,
3610     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3611     DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
3612     DeviceMemory<float>* offset_backprop,
3613     DeviceMemory<uint8>* reserve_space_data,
3614     ScratchAllocator* workspace_allocator) {
3615   return DoBatchNormalizationBackwardImpl<Eigen::half, float>(
3616       stream, miopenHalf, miopenFloat, y_backprop, x, scale, mean, inv_var,
3617       x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
3618       offset_backprop);
3619 }
3620 
3621 bool MIOpenSupport::DoBatchNormalizationBackward(
3622     Stream* stream, const DeviceMemory<float>& y_backprop,
3623     const DeviceMemory<float>& x, const DeviceMemory<float>& scale,
3624     const DeviceMemory<float>& mean, const DeviceMemory<float>& variance,
3625     const dnn::BatchDescriptor& x_desc,
3626     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3627     DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
3628     DeviceMemory<float>* offset_backprop,
3629     DeviceMemory<uint8>* reserve_space_data,
3630     ScratchAllocator* workspace_allocator) {
3631   return DoBatchNormalizationBackwardImpl<float, float>(
3632       stream, miopenFloat, miopenFloat, y_backprop, x, scale, mean, variance,
3633       x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
3634       offset_backprop);
3635 }
3636 
3637 template <class T, class U>
3638 bool MIOpenSupport::DoBatchNormalizationBackwardImpl(
3639     Stream* stream, int miopen_input_type, int miopen_scale_type,
3640     const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
3641     const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
3642     const DeviceMemory<U>& variance, const dnn::BatchDescriptor& x_desc,
3643     const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
3644     DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
3645     DeviceMemory<U>* offset_backprop) {
3646   auto miopen = miopen_->GetHandle(parent_, stream);
3647   ScopedTensorDescriptor x_descriptor{
3648       x_desc, static_cast<miopenDataType_t>(miopen_input_type)};
3649   ScopedTensorDescriptor scale_offset_descriptor{
3650       scale_offset_desc, static_cast<miopenDataType_t>(miopen_scale_type)};
3651   miopenBatchNormMode_t mode = miopenBNSpatial;
3652   float one = 1.0;
3653   float zero = 0.0;
3654 
3655   auto status = wrap::miopenBatchNormalizationBackward(
3656       miopen.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
3657       x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
3658       x_descriptor.handle(), x_backprop->opaque(),
3659       scale_offset_descriptor.handle(), scale.opaque(),
3660       scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
3661       mean.opaque(), variance.opaque());
3662   if (status != miopenStatusSuccess) {
3663     LOG(ERROR) << "failed to enqueue backward batch normalization on stream: "
3664                << ToString(status);
3665     return false;
3666   }
3667   return true;
3668 }
3669 
3670 port::Status MIOpenSupport::DoFusedConvolve(
3671     Stream* stream, dnn::DataType input_type, dnn::DataType side_input_type,
3672     dnn::DataType bias_type, dnn::DataType output_type,
3673     const dnn::BatchDescriptor& conv_input_descriptor,
3674     DeviceMemoryBase conv_input_data, double conv_input_scale,
3675     const dnn::FilterDescriptor& filter_descriptor,
3676     DeviceMemoryBase filter_data,
3677     const dnn::ConvolutionDescriptor& convolution_descriptor,
3678     DeviceMemoryBase side_input_data, double side_input_scale,
3679     const dnn::BatchDescriptor& bias_descriptor, DeviceMemoryBase biases,
3680     dnn::ActivationMode activation_mode,
3681     const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data,
3682     ScratchAllocator* scratch_allocator,
3683     const dnn::AlgorithmConfig& algorithm_config,
3684     dnn::ProfileResult* output_profile_result) {
3685   return port::UnimplementedError("fused convolve not implemented yet");
3686 }
3687 
3688 bool MIOpenSupport::DoTransformTensor(Stream* stream,
3689                                       const dnn::BatchDescriptor& input_desc,
3690                                       dnn::DataType input_type,
3691                                       const DeviceMemoryBase& input_data,
3692                                       const dnn::BatchDescriptor& output_desc,
3693                                       dnn::DataType output_type, float scale,
3694                                       DeviceMemoryBase* output_data) {
3695   // ROCM TODO implement this operation
3696   LOG(ERROR) << "transform tensor not implemented yet";
3697   return false;
3698 }
3699 
3700 bool MIOpenSupport::DoMatMul(Stream* stream,
3701                              const DeviceMemory<float>& input_data,
3702                              const DeviceMemory<float>& weights,
3703                              const dnn::BatchDescriptor& input_dimensions,
3704                              const dnn::BatchDescriptor& output_dimensions,
3705                              DeviceMemory<float>* output_data) {
3706   if (input_dimensions.count() != output_dimensions.count()) {
3707     LOG(ERROR) << "MatMul input and output dimensions are not compatible.";
3708     return false;
3709   }
3710 
3711   // We do not permute the input or output, instead we just
3712   // reinterpret the layout. We are working with row-major matrices
3713   // and the rows of the input and output correspond to batch, so
3714   // batch has to be outermost in both the input and output.
3715   //
3716   // By adding transposes to the BLAS gemm call we could perhaps make
3717   // the kYXDepthBatch layout work as well, but there has been no need
3718   // for that so far.
3719   if (input_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3720       input_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3721     LOG(ERROR) << "Unsupported MatMul input layout.";
3722     return false;
3723   }
3724   if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3725       output_dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
3726     LOG(ERROR) << "Unsupported MatMul output layout.";
3727     return false;
3728   }
3729 
3730   if (output_dimensions.width() == 1 && output_dimensions.height() == 1) {
3731     // This is a fast path that also supports the kBatchYXDepth layout.
3732 
3733     // The matrices here are in row-major format while BLAS expects
3734     // column-major, i.e. our matrices are transposed as far as BLAS
3735     // is concerned. So we need to compute output^T =
3736     // input^T*weights^T. There is no parameter for transposing the
3737     // output in BLAS gemm, but instead we can transpose both sides of
3738     // the equality to see that this is equivalent to
3739     // output=weights*input. So we only need to swap the order of
3740     // weights and input in the matrix product to correct for the
3741     // row-major versus column-major difference.
3742     const int64 m = output_dimensions.NodesAcrossFeatureMaps();
3743     const int64 n = input_dimensions.count();
3744     const int64 k = input_dimensions.NodesAcrossFeatureMaps();
3745     if (!stream
3746              ->ThenBlasGemm(blas::Transpose::kNoTranspose,
3747                             blas::Transpose::kNoTranspose, m, n, k, weights, m,
3748                             input_data, k, output_data, m)
3749              .ok()) {
3750       return false;
3751     }
3752   } else {
3753     // This is a slower and more complex path that supports output
3754     // width() * height() > 1, though it only supports the
3755     // kBatchYXDepth layout. Does support kBatchDepthYX if output
3756     // feature_map_count() == 1, as then there is no difference
3757     // between the two layouts.
3758     //
3759     // The operation here is the same as above, except that we have to
3760     // do the matrix multiplication for each (y,x) output coordinate
3761     // separately. We then interpret weights as containing K = width()
3762     // * height() different matrices, which we all multiply onto the
3763     // matrix from input_data, yielding K matrix products. We then
3764     // combine these together into one matrix by concatenating all the
3765     // first rows of these matrices, then all the second rows and so
3766     // on. We can do this with a batched matrix multiplication, where
3767     // the result is written to a different submatrix of the output
3768     // for each matrix multiplication.
3769     //
3770     // The reason that we only support the kBatchYXDepth output layout
3771     // is that we have to do something in the depth for each (y,x)
3772     // coordinate. The kBatchYXDepth layout has the depth information
3773     // for each point (y,x) in contiguous memory while the
3774     // kBatchDepthYX layout does not.
3775     //
3776     // TODO(broune): Consider a special case for when output depth ==
3777     // 1, as then possibly this could all be done as one matrix
3778     // multiplication instead of a batched one, which should be
3779     // faster. Another possibility would be to add a weights layout
3780     // parameter and then support kBatchDepthYX for a different
3781     // weights layout.
3782     if (output_dimensions.layout() != dnn::DataLayout::kBatchYXDepth &&
3783         !(output_dimensions.layout() == dnn::DataLayout::kBatchDepthYX &&
3784           output_dimensions.feature_map_count() == 1)) {
3785       LOG(ERROR) << "Unsupported MatMul output layout.";
3786       return false;
3787     }
3788 
3789     const float alpha = 1.0f;  // Take the matrix product without scaling it.
3790     const float beta = 0.0f;   // Ignore the original values in output_data.
3791     const uint64 m = output_dimensions.feature_map_count();
3792     const uint64 n = input_dimensions.count();
3793     const uint64 k = input_dimensions.NodesAcrossFeatureMaps();
3794     const int lda = m;
3795     const int ldb = k;
3796     const int ldc = output_dimensions.NodesAcrossFeatureMaps();
3797     const int batch_count = output_dimensions.NodesPerFeatureMap();
3798 
3799     std::vector<DeviceMemory<float>> a(batch_count);
3800     std::vector<DeviceMemory<float>> b(batch_count);
3801     std::vector<DeviceMemory<float>> c(batch_count);
3802     for (int i = 0; i < batch_count; ++i) {
3803       const int weights_offset = i * input_dimensions.NodesAcrossFeatureMaps() *
3804                                  output_dimensions.feature_map_count();
3805       a[i] = DeviceMemory<float>::MakeFromByteSize(
3806           const_cast<float*>(reinterpret_cast<const float*>(weights.opaque())) +
3807               weights_offset,
3808           weights.ElementCount() - weights_offset);
3809 
3810       b[i] = input_data;
3811 
3812       const int output_offset = i * output_dimensions.feature_map_count();
3813       c[i] = DeviceMemory<float>::MakeFromByteSize(
3814           const_cast<float*>(
3815               reinterpret_cast<const float*>(output_data->opaque())) +
3816               output_offset,
3817           output_data->ElementCount() - output_offset);
3818     }
3819     const auto toPtrs = [](std::vector<DeviceMemory<float>>& v) {
3820       std::vector<DeviceMemory<float>*> ptrs;
3821       ptrs.reserve(v.size());
3822       for (auto& mem : v) {
3823         ptrs.push_back(&mem);
3824       }
3825       return ptrs;
3826     };
3827 
3828     stream->ThenBlasGemmBatched(blas::Transpose::kNoTranspose,
3829                                 blas::Transpose::kNoTranspose, m, n, k, alpha,
3830                                 toPtrs(a), lda, toPtrs(b), ldb, beta, toPtrs(c),
3831                                 ldc, batch_count);
3832   }
3833 
3834   return stream->ok();
3835 }
3836 
3837 bool MIOpenSupport::DoBiasAdd(Stream* stream,
3838                               const DeviceMemory<float>& input_data,
3839                               const DeviceMemory<float>& biases,
3840                               const dnn::BatchDescriptor& dimensions,
3841                               DeviceMemory<float>* output_data) {
3842   ScopedTensorDescriptor input_descriptor{dimensions, miopenFloat};
3843 
3844   BatchDescriptor bias_dimensions;
3845   bias_dimensions.set_count(1)
3846       .set_feature_map_count(dimensions.feature_map_count())
3847       .set_height(1)
3848       .set_width(1)
3849       .set_layout(dnn::DataLayout::kBatchYXDepth);
3850   ScopedTensorDescriptor bias_descriptor{bias_dimensions, miopenFloat};
3851 
3852   if (input_data.opaque() != output_data->opaque()) {
3853     stream->ThenMemcpy(output_data, input_data,
3854                        dimensions.ElementCount() * sizeof(float));
3855     if (!stream->ok()) {
3856       LOG(ERROR)
3857           << "stream " << stream
3858           << " could not enqueue a tensor copy as part of bias addition.";
3859       return false;
3860     }
3861   }
3862 
3863   auto miopen = miopen_->GetHandle(parent_, stream);
3864 
3865   const float alpha1 = 1.0f;
3866   const float alpha2 = 0.0f;
3867   const float beta = 1.0f;
3868 
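  // miopenOpTensor computes C = op(alpha1 * A, alpha2 * B) + beta * C
  // (mirroring cudnnOpTensor). With op = Add, alpha2 = 0 and beta = 1 this
  // reduces to output += bias, where the 1x1xC bias tensor is broadcast
  // across batch, height and width of the output (which already holds a copy
  // of the input).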
3869   auto status = wrap::miopenOpTensor(
3870       miopen.handle(), miopenTensorOpAdd, &alpha1, bias_descriptor.handle(),
3871       biases.opaque(), &alpha2, bias_descriptor.handle(), biases.opaque(),
3872       &beta, input_descriptor.handle(), output_data->opaque());
3873 
3874   if (status != miopenStatusSuccess) {
3875     LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
3876     return false;
3877   }
3878 
3879   return true;
3880 }
3881 
3882 bool MIOpenSupport::DoActivate(Stream* stream,
3883                                dnn::ActivationMode activation_mode,
3884                                const dnn::BatchDescriptor& dimensions,
3885                                const DeviceMemory<float>& input_data,
3886                                DeviceMemory<float>* output_data,
3887                                uint64 options) {
3888   LOG(ERROR) << "miopen does not support activation yet";
3889   return false;
3890 }
3891 
3892 bool MIOpenSupport::DoPoolForward(
3893     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3894     const dnn::BatchDescriptor& input_dimensions,
3895     const DeviceMemory<double>& input_data,
3896     const dnn::BatchDescriptor& output_dimensions,
3897     DeviceMemory<double>* output_data, ScratchAllocator* workspace_allocator) {
3898   LOG(ERROR) << "miopen does not support pooling for double type yet";
3899   return false;
3900 }
3901 
3902 bool PoolingWorkspaceDescriptor::IsSame(
3903     const dnn::BatchDescriptor& input_dimensions,
3904     const dnn::BatchDescriptor& output_dimensions,
3905     const dnn::PoolingDescriptor& pooling_dimensions, int _type) {
3906   return dtype == _type &&
3907          input_dims ==
3908              input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
3909          output_dims ==
3910              output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX) &&
3911          op.mode() == pooling_dimensions.mode() &&
3912          op.window() == pooling_dimensions.window() &&
3913          op.padding() == pooling_dimensions.padding() &&
3914          op.strides() == pooling_dimensions.strides();
3915 }
3916 
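// Looks up a cached pooling workspace by the input tensor pointer. An entry
// is only returned when the tensor shapes, data type and pooling attributes
// all match, since the saved workspace is only valid for that exact
// configuration.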
3917 bool PoolingWorkspaceCache::find(
3918     const void* p, const dnn::BatchDescriptor& input_dimensions,
3919     const dnn::BatchDescriptor& output_dimensions,
3920     const dnn::PoolingDescriptor& pooling_dimensions, int _type,
3921     PoolingWorkspaceDescriptor*& pdesc) {
3922   pdesc = 0;
3923   auto it = cache.find(p);
3924   if (it == cache.end()) {
3925     return false;
3926   }
3927   if (!it->second.IsSame(input_dimensions, output_dimensions,
3928                          pooling_dimensions, _type)) {
3929     return false;
3930   }
3931   pdesc = &it->second;
3932   return true;
3933 }
3934 
3935 void PoolingWorkspaceCache::insert(
3936     const void* p, const dnn::BatchDescriptor& input_dimensions,
3937     const dnn::BatchDescriptor& output_dimensions,
3938     const dnn::PoolingDescriptor& pooling_dimensions, int _type,
3939     std::unique_ptr<TemporaryDeviceMemory<uint8>>& workspace, size_t wsp_size,
3940     hipStream_t hip_stream) {
3941   PoolingWorkspaceDescriptor* desc = 0;
3942   auto it = cache.find(p);
3943   if (it != cache.end()) {
3944     // replacing an entry with the same pointer but different attributes
3945     // (if everything matches, the caller is expected to reuse the entry)
3946     desc = &it->second;
3947     hipStreamSynchronize(hip_stream);
3948     memory_used -= desc->workspace_size;
3949   } else {
3950     cache[p] = PoolingWorkspaceDescriptor();
3951     desc = &cache[p];
3952   }
3953   desc->input_dims = input_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
3954   desc->output_dims =
3955       output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
3956   desc->op = pooling_dimensions;
3957   desc->dtype = _type;
3958   desc->timestamp = timestamp;
3959   timestamp++;
3960   desc->workspace = std::move(workspace);
3961   desc->workspace_size = wsp_size;
3962   memory_used += wsp_size;
3963   trim(hip_stream);
3964 }
3965 
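// Evicts roughly the oldest quarter of the cache whenever the memory budget
// or the entry-count threshold is exceeded. The stream is synchronized before
// any workspace is freed so that kernels still using a cached buffer have
// completed, and eviction repeats until usage drops back under budget (or
// the cache becomes small).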
3966 void PoolingWorkspaceCache::trim(hipStream_t hip_stream) {
3967   if (memory_used < memory_budget && cache.size() < trim_size) return;
3968   bool must_sync = true;
3969   while (true) {
3970     int new_size = cache.size() - (cache.size() >> 2);
3971     std::vector<const void*> old_entries;
3972     for (auto& x : cache)
3973       if (x.second.timestamp + new_size < timestamp)
3974         old_entries.push_back(x.first);
3975     if (old_entries.empty()) break;
3976     if (must_sync) hipStreamSynchronize(hip_stream);
3977     must_sync = true;
3978     for (auto x : old_entries) {
3979       memory_used -= cache[x].workspace_size;
3980       cache.erase(x);
3981     }
3982     if (memory_used < memory_budget || cache.size() < 10) break;
3983   }
3984 }
3985 
3986 bool MIOpenSupport::DoPoolForward(
3987     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
3988     const dnn::BatchDescriptor& input_dimensions,
3989     const DeviceMemory<float>& input_data,
3990     const dnn::BatchDescriptor& output_dimensions,
3991     DeviceMemory<float>* output_data, ScratchAllocator* workspace_allocator) {
3992   auto miopen = miopen_->GetHandle(parent_, stream);
3993   // Alpha is the scaling factor for input.
3994   float alpha = 1.0;
3995   // Beta is the scaling factor for output.
3996   float beta = 0.0;
3997 
3998   ScopedTensorDescriptor src_desc{input_dimensions, miopenFloat};
3999   ScopedTensorDescriptor dest_desc{output_dimensions, miopenFloat};
4000   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
4001 
4002   bool do_backward = false;
4003   uint8* workspace = 0;
4004   size_t workspace_size = 0;
4005   std::unique_ptr<TemporaryDeviceMemory<uint8>> wsp_mem;
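  // Once a backward pass has been observed (m_pooling_cache_enabled), forward
  // pooling runs with do_backward=true so MIOpen writes its workspace; that
  // workspace is cached keyed by the input pointer and reused by the
  // subsequent DoPoolBackward call.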
4006   if (m_pooling_cache_enabled) {
4007     do_backward = true;
4008     auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
4009         pooling_desc.handle(), dest_desc.handle(), &workspace_size);
4010     if (status != miopenStatusSuccess) {
4011       LOG(ERROR)
4012           << "failed to obtain workspace size for backward pooling on stream: "
4013           << ToString(status);
4014       return false;
4015     }
4016     if (workspace_size != 0) {
4017       PoolingWorkspaceDescriptor* pdesc = 0;
4018       bool cache_hit =
4019           m_pooling_cache_allowed &&
4020           m_pooling_cache.find(input_data.opaque(), input_dimensions,
4021                                output_dimensions, pooling_dimensions,
4022                                miopenFloat, pdesc);
4023       if (cache_hit) {
4024         // reusing the same buffer
4025         workspace = reinterpret_cast<uint8*>(
4026             pdesc->workspace->mutable_device_memory()->opaque());
4027       } else {
4028         wsp_mem = stream->AllocateTemporaryArray<uint8>(workspace_size)
4029                       .ConsumeValueOrDie();
4030         workspace = reinterpret_cast<uint8*>(
4031             wsp_mem->mutable_device_memory()->opaque());
4032         m_pooling_cache.insert(input_data.opaque(), input_dimensions,
4033                                output_dimensions, pooling_dimensions,
4034                                miopenFloat, wsp_mem, workspace_size,
4035                                AsGpuStreamValue(stream));
4036       }
4037     }
4038   }
4039 
4040   auto status = wrap::miopenPoolingForward(
4041       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4042       input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
4043       do_backward, workspace, workspace_size);
4044   if (status != miopenStatusSuccess) {
4045     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
4046                << ToString(status);
4047     return false;
4048   }
4049   return true;
4050 }
4051 
4052 bool MIOpenSupport::DoPoolForward(
4053     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4054     const dnn::BatchDescriptor& input_dimensions,
4055     const DeviceMemory<Eigen::half>& input_data,
4056     const dnn::BatchDescriptor& output_dimensions,
4057     DeviceMemory<Eigen::half>* output_data,
4058     ScratchAllocator* workspace_allocator) {
4059   auto miopen = miopen_->GetHandle(parent_, stream);
4060 
4061   // Alpha is the scaling factor for input.
4062   float alpha = 1.0;
4063   // Beta is the scaling factor for output.
4064   float beta = 0.0;
4065 
4066   ScopedTensorDescriptor src_desc{input_dimensions, miopenHalf};
4067   ScopedTensorDescriptor dest_desc{output_dimensions, miopenHalf};
4068   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
4069 
4070   auto status = wrap::miopenPoolingForward(
4071       miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4072       input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque(),
4073       false, nullptr, 0);
4074   if (status != miopenStatusSuccess) {
4075     LOG(ERROR) << "failed to enqueue forward pooling on stream: "
4076                << ToString(status);
4077     return false;
4078   }
4079   return true;
4080 }
4081 
4082 template <class T>
4083 bool MIOpenSupport::DoPoolBackwardImpl(
4084     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4085     const dnn::BatchDescriptor& input_dimensions,
4086     const DeviceMemory<T>& input_data,
4087     const dnn::BatchDescriptor& output_dimensions,
4088     const DeviceMemory<T>& output_data, const DeviceMemory<T>& input_diff_data,
4089     DeviceMemory<T>* output_diff_data, ScratchAllocator* workspace_allocator) {
4090   auto miopen = miopen_->GetHandle(parent_, stream);
4091   if (m_pooling_cache_allowed) m_pooling_cache_enabled = true;
4092   // Alpha is the scaling factor for input.
4093   float alpha = 1.0;
4094   // Beta is the scaling factor for output.
4095   float beta = 0.0;
4096 
4097   auto type =
4098       std::is_same<T, float>::value
4099           ? miopenFloat
4100           : (std::is_same<T, Eigen::half>::value ? miopenHalf
4101                                                  : (miopenDataType_t)-1);
4102 
4103   ScopedTensorDescriptor src_desc{input_dimensions, type};
4104   ScopedTensorDescriptor dest_desc{output_dimensions, type};
4105   ScopedPoolingDescriptor pooling_desc{pooling_dimensions};
4106 
4107   uint8* workspace_ptr = 0;
4108   DeviceMemory<uint8> workspace;
4109   PoolingWorkspaceDescriptor* pdesc = 0;
4110 
4111   size_t workspace_size_in_bytes = 0;
4112   auto status = wrap::miopenPoolingGetWorkSpaceSizeV2(
4113       pooling_desc.handle(), dest_desc.handle(), &workspace_size_in_bytes);
4114   if (status != miopenStatusSuccess) {
4115     LOG(ERROR)
4116         << "failed to obtain workspace size for backward pooling on stream: "
4117         << ToString(status);
4118     return false;
4119   }
4120 
4121   // Allocate the workspace.
4122   if (workspace_size_in_bytes > 0) {
4123     bool cache_hit = m_pooling_cache_allowed &&
4124                      m_pooling_cache.find(input_data.opaque(), input_dimensions,
4125                                           output_dimensions, pooling_dimensions,
4126                                           type, pdesc);
4127     if (cache_hit) {
4128       assert(pdesc != 0);
4129       workspace_ptr = reinterpret_cast<uint8*>(
4130           pdesc->workspace->mutable_device_memory()->opaque());
4131       VLOG(1) << "Pooling cache hit";
4132     } else {
4133       VLOG(1) << "Pooling cache miss";
4134       assert(workspace_allocator);
4135       auto allocated =
4136           workspace_allocator->AllocateBytes(workspace_size_in_bytes);
4137       if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
4138         LOG(ERROR) << "Failed to allocate backward pooling workspace";
4139         return false;
4140       }
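      // Cache miss: the workspace produced by the original forward pass is
      // not available, so forward pooling is re-run into a throwaway dest2
      // buffer purely to regenerate the workspace that miopenPoolingBackward
      // requires.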
4141       DeviceMemory<uint8> dest2;  // duplicated dest from forward:
4142       int64 dest2_size = 0;
4143 
4144       // miopen requires the strides and dims to be ordered as BDYX.
4145       std::vector<int64> dims64 =
4146           output_dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4147       // miopen does not use strides and must have 4D tensor.
4148       // std::vector<int> dims(pooling_dimensions.ndims() + 2);
4149 
4150       dest2_size = sizeof(T);
4151       for (auto& x : dims64) dest2_size *= x;
4152 
4153       if (dest2_size > 0) {
4154         assert(workspace_allocator);
4155         auto allocated = workspace_allocator->AllocateBytes(dest2_size);
4156         if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
4157           LOG(ERROR) << "Failed to allocate backward pooling workspace";
4158           return false;
4159         }
4160       } else {
4161         LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
4162                       "backward pooling";
4163       }
4164 
4165       status = wrap::miopenPoolingForward(
4166           miopen.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
4167           input_data.opaque(), &beta, dest_desc.handle(), dest2.opaque(), true,
4168           workspace.opaque(), workspace_size_in_bytes);
4169 
4170       if (status != miopenStatusSuccess) {
4171         LOG(ERROR)
4172             << "failed to enqueue forward pooling (before backward) on stream: "
4173             << ToString(status);
4174         return false;
4175       }
4176       workspace_ptr = reinterpret_cast<uint8*>(workspace.opaque());
4177     }
4178   }
4179   status = wrap::miopenPoolingBackward(
4180       miopen.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
4181       output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
4182       src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
4183       output_diff_data->opaque(), workspace_ptr);
4184 
4185   if (status != miopenStatusSuccess) {
4186     LOG(ERROR) << "failed to enqueue backward pooling on stream: "
4187                << ToString(status);
4188     return false;
4189   }
4190 
4191   return true;
4192 }
4193 
4194 bool MIOpenSupport::DoPoolBackward(
4195     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4196     const dnn::BatchDescriptor& input_dimensions,
4197     const DeviceMemory<double>& input_data,
4198     const dnn::BatchDescriptor& output_dimensions,
4199     const DeviceMemory<double>& output_data,
4200     const DeviceMemory<double>& input_diff_data,
4201     DeviceMemory<double>* output_diff_data,
4202     ScratchAllocator* workspace_allocator) {
4203   LOG(ERROR) << "miopen does not support backward pooling on double type yet";
4204   return false;
4205 }
4206 
4207 bool MIOpenSupport::DoPoolBackward(
4208     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4209     const dnn::BatchDescriptor& input_dimensions,
4210     const DeviceMemory<float>& input_data,
4211     const dnn::BatchDescriptor& output_dimensions,
4212     const DeviceMemory<float>& output_data,
4213     const DeviceMemory<float>& input_diff_data,
4214     DeviceMemory<float>* output_diff_data,
4215     ScratchAllocator* workspace_allocator) {
4216   return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
4217                             input_data, output_dimensions, output_data,
4218                             input_diff_data, output_diff_data,
4219                             workspace_allocator);
4220 }
4221 
4222 bool MIOpenSupport::DoPoolBackward(
4223     Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
4224     const dnn::BatchDescriptor& input_dimensions,
4225     const DeviceMemory<Eigen::half>& input_data,
4226     const dnn::BatchDescriptor& output_dimensions,
4227     const DeviceMemory<Eigen::half>& output_data,
4228     const DeviceMemory<Eigen::half>& input_diff_data,
4229     DeviceMemory<Eigen::half>* output_diff_data,
4230     ScratchAllocator* workspace_allocator) {
4231   return DoPoolBackwardImpl(stream, pooling_dimensions, input_dimensions,
4232                             input_data, output_dimensions, output_data,
4233                             input_diff_data, output_diff_data,
4234                             workspace_allocator);
4235 }
4236 
4237 bool MIOpenSupport::DoNormalizeWithDimensions(
4238     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4239     const dnn::BatchDescriptor& dimensions,
4240     const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
4241   // Check for unsupported modes.
4242   if (normalize_descriptor.wrap_around()) {
4243     LOG(ERROR) << "MIOpen LRN does not support wrap-around mode";
4244     return false;
4245   }
4246   if (normalize_descriptor.segment_size()) {
4247     LOG(ERROR) << "MIOpen LRN does not support segmentation";
4248     return false;
4249   }
4250 
4251   auto miopen = miopen_->GetHandle(parent_, stream);
4252 
4253   // Launch the normalization.
4254   ScopedTensorDescriptor dims{dimensions, miopenFloat};
4255   ScopedNormalizeDescriptor normalize{normalize_descriptor};
4256 
4257   // Alpha is the scaling factor for input.
4258   float alpha = 1.0f;
4259   // Beta is the scaling factor for output.
4260   float beta = 0.0f;
4261 
4262   auto status = wrap::miopenLRNForward(
4263       miopen.handle(), normalize.handle(), &alpha, dims.handle(),
4264       input_data.opaque(), &beta, dims.handle(), output_data->opaque(), false,
4265       nullptr);
4266   if (status != miopenStatusSuccess) {
4267     LOG(ERROR) << "failed to run miopenLRNForward";
4268     return false;
4269   }
4270   return true;
4271 }
4272 
4273 bool MIOpenSupport::DoNormalizeBackwardWithDimensions(
4274     Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
4275     const dnn::BatchDescriptor& dimensions, const DeviceMemory<float>& raw_data,
4276     const DeviceMemory<float>& normalized_data,
4277     const DeviceMemory<float>& normalized_variable_gradient,
4278     DeviceMemory<float>* raw_variable_gradient,
4279     ScratchAllocator* workspace_allocator) {
4280   // Check for unsupported modes.
4281   if (normalize_descriptor.wrap_around()) {
4282     LOG(ERROR) << "MIOpen LRN does not support wrap-around mode";
4283     return false;
4284   }
4285   if (normalize_descriptor.segment_size()) {
4286     LOG(ERROR) << "MIOpen LRN does not support segmentation";
4287     return false;
4288   }
4289 
4290   auto miopen = miopen_->GetHandle(parent_, stream);
4291 
4292   ScopedTensorDescriptor dims{dimensions, miopenFloat};
4293   ScopedNormalizeDescriptor normalize{normalize_descriptor};
4294 
4295   float alpha = 1.0f;
4296   float beta = 0.0f;
4297 
4298   DeviceMemory<uint8> workspace;
4299   size_t workspace_size_in_bytes = 0;
4300   auto status =
4301       wrap::miopenLRNGetWorkSpaceSize(dims.handle(), &workspace_size_in_bytes);
4302 
4303   if (status != miopenStatusSuccess) {
4304     LOG(ERROR) << "failed to obtain workspace size for miopenLRNBackward";
4305     return false;
4306   }
4307 
4308   // Allocate the workspace.
4309   if (workspace_size_in_bytes > 0) {
4310     assert(workspace_allocator);
4311     auto allocated =
4312         workspace_allocator->AllocateBytes(workspace_size_in_bytes);
4313     if (!allocated.ok() || (workspace = allocated.ValueOrDie()) == nullptr) {
4314       LOG(ERROR) << "Failed to allocate backward pooling workspace";
4315       return false;
4316     }
4317   }
4318 
4319   DeviceMemory<uint8> dest2;  // duplicated dest from forward:
4320   int dest2_size = 0;
4321 
4322   // miopen requires the strides and dims to be ordered as BDYX.
4323   std::vector<int64> dims64 =
4324       dimensions.full_dims(dnn::DataLayout::kBatchDepthYX);
4325 
4326   // miopen does not use strides and must have 4D tensor.
4327   std::vector<int> dimsint(4);
4328 
4329   std::transform(dims64.cbegin(), dims64.cend(), dimsint.begin(),
4330                  &CheckedNarrowing<int64, int>);
4331 
4332   dest2_size =
4333       dimsint[0] * dimsint[1] * dimsint[2] * dimsint[3] * sizeof(float);
4334 
4335   if (dest2_size > 0) {
4336     assert(workspace_allocator);
4337     auto allocated = workspace_allocator->AllocateBytes(dest2_size);
4338     if (!allocated.ok() || (dest2 = allocated.ValueOrDie()) == nullptr) {
4339       LOG(ERROR)
4340           << "Failed to allocate tensor to chain forward and backward LRN";
4341       return false;
4342     }
4343   } else {
4344     LOG(ERROR) << "Failed to calculate tensor size to chain forward and "
4345                   "backward LRN";
4346   }
4347 
4348   status = wrap::miopenLRNForward(miopen.handle(), normalize.handle(), &alpha,
4349                                   dims.handle(), raw_data.opaque(), &beta,
4350                                   dims.handle(), dest2.opaque(), true,
4351                                   workspace.opaque());
4352 
4353   if (status != miopenStatusSuccess) {
4354     LOG(ERROR) << "failed to run miopenLRNForward";
4355     return false;
4356   }
4357 
4358   status = wrap::miopenLRNBackward(
4359       miopen.handle(), normalize.handle(), &alpha, dims.handle(),
4360       normalized_data.opaque(), dims.handle(),
4361       normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
4362       &beta, dims.handle(), raw_variable_gradient->opaque(),
4363       workspace.opaque());
4364 
4365   if (status != miopenStatusSuccess) {
4366     LOG(ERROR) << "failed to run miopenLRNBackward";
4367     return false;
4368   }
4369   return true;
4370 }
4371 
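// Host-side fallback for depth concatenation: each input is copied to the
// host, interleaved into the output layout in host memory, and the result is
// copied back to the device. Only the kBatchDepthYX layout is accepted.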
bool MIOpenSupport::DoDepthConcatenate(
    Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float>*> input_data,
    DeviceMemory<float>* output_data) {
  CHECK_EQ(input_dimensions.size(), input_data.size());

  for (const auto& dimensions : input_dimensions) {
    if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
      LOG(ERROR) << "MIOpenSupport::DoDepthConcatenate currently only "
                    "supports the kBatchDepthYX layout.";
      return false;
    }
  }

  if (input_dimensions.empty()) {
    return true;  // Nothing to do.
  }

  dnn::BatchDescriptor output_dimensions =
      dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);

  const int64 area = output_dimensions.width() * output_dimensions.height();
  const auto index = [area](int64 batch, int64 depth, int64 yx,
                            int64 max_depth) {
    return (batch * max_depth + depth) * area + yx;
  };

  std::vector<float> output_host(output_dimensions.ElementCount());
  std::vector<float> tmp;
  int64 depth_sum = 0;
  for (size_t i = 0; i < input_data.size(); ++i) {
    const auto& dimensions = input_dimensions[i];
    tmp.resize(dimensions.ElementCount());
    stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
    port::Status block_status = stream->BlockHostUntilDone();
    if (!block_status.ok()) {
      LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
      return false;
    }

    for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
      for (int64 yx = 0; yx < area; ++yx) {
        for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
          LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
                    << yx << ' ' << depth;
          output_host[index(batch, depth + depth_sum, yx,
                            output_dimensions.feature_map_count())] =
              tmp[index(batch, depth, yx, dimensions.feature_map_count())];
        }
      }
    }
    depth_sum += dimensions.feature_map_count();
  }
  stream->ThenMemcpyH2D<float>(output_host, output_data);
  return true;
}

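// The following dnn::DnnSupport entry points are not implemented by the
// MIOpen backend: elementwise operations, XY padding/slicing, and quantized
// host<->device copies. They either LOG(FATAL) or report an error and
// return false.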
bool MIOpenSupport::DoElementwiseOperate(
    Stream* stream, dnn::ElementwiseOperation operation,
    port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
    port::ArraySlice<const DeviceMemory<float>*> input_data,
    const dnn::BatchDescriptor& output_dimensions,
    DeviceMemory<float>* output_data) {
  LOG(FATAL) << "not yet implemented";  // TODO(leary)
  return false;
}

bool MIOpenSupport::DoXYPad(Stream* stream,
                            const dnn::BatchDescriptor& dimensions,
                            const DeviceMemory<float>& input_data,
                            int64 left_pad, int64 right_pad, int64 top_pad,
                            int64 bottom_pad,
                            DeviceMemory<float>* output_data) {
  LOG(FATAL) << "not yet implemented";  // TODO(leary)
  return false;
}

bool MIOpenSupport::DoXYSlice(Stream* stream,
                              const dnn::BatchDescriptor& dimensions,
                              const DeviceMemory<float>& input_data,
                              int64 left_trim, int64 right_trim, int64 top_trim,
                              int64 bottom_trim,
                              DeviceMemory<float>* output_data) {
  LOG(FATAL) << "not yet implemented";  // TODO(leary)
  return false;
}

bool MIOpenSupport::DoMemcpyD2HQuantized(
    Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
    dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
  LOG(ERROR) << "quantized memcpy not supported by MIOpen";
  return false;
}

bool MIOpenSupport::DoMemcpyH2DQuantized(
    Stream* stream, const void* host_src, int64 size,
    dnn::QuantizedActivationMode mode,
    DeviceMemory<float>* gpu_unquantized_dst) {
  LOG(ERROR) << "quantized memcpy not supported by MIOpen";
  return false;
}

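// Derives the output BatchDescriptor for a convolution by querying MIOpen via
// miopenGetConvolutionNdForwardOutputDim. The returned dims are in BDYX
// order, so the spatial dims are filled starting from the innermost (X)
// dimension, which is why the loop indexes through rbegin().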
bool MIOpenSupport::DeriveOutputBatchDescriptor(
    const BatchDescriptor& batch_descriptor,
    const FilterDescriptor& filter_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    dnn::BatchDescriptor* output_batch_descriptor) {
  ScopedTensorDescriptor input_nd{batch_descriptor, miopenFloat};
  ScopedFilterDescriptor filter{filter_descriptor, miopenFloat};
  ScopedConvolutionDescriptor conv{convolution_descriptor, miopenFloat};

  int dn = batch_descriptor.ndims() + 2;
  std::vector<int> dims(dn);  // in BDYX
  auto status = wrap::miopenGetConvolutionNdForwardOutputDim(
      conv.handle(), input_nd.handle(), filter.handle(), &dn, dims.data());
  if (status != miopenStatusSuccess) {
    LOG(ERROR) << "could not get output tensor for convolution: "
               << ToString(status);
    return false;
  }

  output_batch_descriptor->set_count(dims[0])
      .set_feature_map_count(dims[1])
      .set_layout(batch_descriptor.layout());

  for (int i = 0; i < batch_descriptor.ndims(); i++) {
    output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
                                             dims.rbegin()[i]);
  }

  return true;
}

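// Fused convolution + bias + activation using a MIOpen fusion plan. The
// ScopedFusionPlanConvolutionBiasActivation wrapper compiles the plan; when
// compilation succeeds, the per-op arguments (filter, bias, activation) are
// bound and the plan is executed. If a ProfileResult is supplied, a GpuTimer
// brackets the execution and a failing status is swallowed instead of hitting
// the LOG(FATAL) path, presumably so autotuning can simply skip this plan.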
template <typename T>
bool MIOpenSupport::DoFusedConvolutionBiasActivationImpl(
    Stream* stream,
    int miopen_type,  // Actually miopenDataType_t.
    const dnn::BatchDescriptor& conv_input_descriptor,
    const DeviceMemory<T>& conv_input_data,
    const dnn::FilterDescriptor& filter_descriptor,
    const DeviceMemory<T>& filter_data,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    const dnn::BatchDescriptor& bias_descriptor,
    const DeviceMemory<T>& bias_data, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor& output_descriptor, DeviceMemory<T>* output_data,
    dnn::ProfileResult* output_profile_result) {
  auto miopen = miopen_->GetHandle(parent_, stream);

  ScopedTensorDescriptor conv_input_nd{
      conv_input_descriptor, static_cast<miopenDataType_t>(miopen_type)};

  ScopedTensorDescriptor bias_nd{bias_descriptor,
                                 static_cast<miopenDataType_t>(miopen_type)};

  ScopedTensorDescriptor output_nd{output_descriptor,
                                   static_cast<miopenDataType_t>(miopen_type)};

  ScopedConvolutionDescriptor conv{convolution_descriptor,
                                   static_cast<miopenDataType_t>(miopen_type)};

  ScopedFilterDescriptor filter{filter_descriptor,
                                static_cast<miopenDataType_t>(miopen_type)};

  ScopedActivationDescriptor activation_desc{activation_mode};

  ScopedFusionPlanConvolutionBiasActivation fusion_plan{
      miopen.handle(), conv_input_nd.handle(), filter.handle(),
      conv.handle(),   bias_nd.handle(),       activation_desc};

  bool retval = false;

  if (fusion_plan.CompilationSucceeded()) {
    const bool is_profiling = output_profile_result != nullptr;

    std::unique_ptr<GpuTimer> timer;
    if (is_profiling) {
      timer.reset(new GpuTimer(parent_));
      timer->Init();
      timer->Start(AsGpuStream(stream));
    }

    miopenStatus_t status = miopenStatusSuccess;

    if (status == miopenStatusSuccess) {
      fusion_plan.SetConvolutionArgs(filter_data.opaque());
    }

    if (status == miopenStatusSuccess) {
      status = fusion_plan.SetBiasArgs(bias_data.opaque());
    }

    if (status == miopenStatusSuccess) {
      status = fusion_plan.SetActivationForwardArgs(activation_desc);
    }

    if (status == miopenStatusSuccess) {
      status =
          fusion_plan.Execute(conv_input_nd.handle(), conv_input_data.opaque(),
                              output_nd.handle(), output_data->opaque());
    }

    if (is_profiling) {
      timer->Stop(AsGpuStream(stream));
      if (status == miopenStatusSuccess) {
        output_profile_result->set_elapsed_time_in_ms(
            timer->GetElapsedMilliseconds());
      }
      timer->Destroy();
    }

    if (status != miopenStatusSuccess) {
      // Silently return when we are profiling.
      if (!is_profiling) {
        LOG(FATAL) << "failed to enqueue fused-convolution on stream: "
                   << ToString(status);
      }
    }

    retval = true;
  }

  return retval;
}

bool MIOpenSupport::DoFusedConvolutionBiasActivation(
    Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
    const DeviceMemory<float>& conv_input_data,
    const dnn::FilterDescriptor& filter_descriptor,
    const DeviceMemory<float>& filter_data,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    const dnn::BatchDescriptor& bias_descriptor,
    const DeviceMemory<float>& bias_data, dnn::ActivationMode activation_mode,
    const dnn::BatchDescriptor& output_descriptor,
    DeviceMemory<float>* output_data,
    dnn::ProfileResult* output_profile_result) {
  return DoFusedConvolutionBiasActivationImpl<float>(
      stream, miopenFloat, conv_input_descriptor, conv_input_data,
      filter_descriptor, filter_data, convolution_descriptor, bias_descriptor,
      bias_data, activation_mode, output_descriptor, output_data,
      output_profile_result);
}

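// Fused batch-norm + activation for inference, again expressed as a MIOpen
// fusion plan: bind the precomputed mean/variance plus scale and offset, bind
// the activation, then execute using the same descriptor for x and y.
// Profiling follows the same GpuTimer pattern as the fused convolution above.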
template <typename T, typename U>
bool MIOpenSupport::DoFusedBatchNormActivationInferenceImpl(
    Stream* stream,
    int miopen_type,  // Actually miopenDataType_t.
    const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
    const DeviceMemory<U>& mean_data, const DeviceMemory<U>& variance_data,
    double epsilon, dnn::ActivationMode activation_mode,
    DeviceMemory<T>* y_data, dnn::ProfileResult* output_profile_result) {
  auto miopen = miopen_->GetHandle(parent_, stream);

  ScopedTensorDescriptor x_nd{x_descriptor,
                              static_cast<miopenDataType_t>(miopen_type)};

  ScopedTensorDescriptor scale_offset_mean_variance_nd{
      scale_offset_mean_variance_descriptor,
      static_cast<miopenDataType_t>(miopen_type)};

  ScopedActivationDescriptor activation_desc{activation_mode};

  ScopedFusionPlanBatchNormActivationInference fusion_plan{
      miopen.handle(), x_nd.handle(), scale_offset_mean_variance_nd.handle(),
      activation_desc};

  bool retval = false;

  if (fusion_plan.CompilationSucceeded()) {
    const bool is_profiling = output_profile_result != nullptr;

    std::unique_ptr<GpuTimer> timer;
    if (is_profiling) {
      timer.reset(new GpuTimer(parent_));
      timer->Init();
      timer->Start(AsGpuStream(stream));
    }

    miopenStatus_t status = miopenStatusSuccess;

    if (status == miopenStatusSuccess) {
      fusion_plan.SetBatchNormInferenceArgs(
          scale_data.opaque(), offset_data.opaque(), mean_data.opaque(),
          variance_data.opaque(), epsilon);
    }

    if (status == miopenStatusSuccess) {
      status = fusion_plan.SetActivationForwardArgs(activation_desc);
    }

    if (status == miopenStatusSuccess) {
      status = fusion_plan.Execute(x_nd.handle(), x_data.opaque(),
                                   x_nd.handle(), y_data->opaque());
    }

    if (is_profiling) {
      timer->Stop(AsGpuStream(stream));
      if (status == miopenStatusSuccess) {
        output_profile_result->set_elapsed_time_in_ms(
            timer->GetElapsedMilliseconds());
      }
      timer->Destroy();
    }

    if (status != miopenStatusSuccess) {
      // Silently return when we are profiling.
      if (!is_profiling) {
        LOG(FATAL) << "failed to enqueue fused batch-norm activation "
                      "(inference) on stream: "
                   << ToString(status);
      }
    }

    retval = true;
  }

  return retval;
}

bool MIOpenSupport::DoFusedBatchNormActivationInference(
    Stream* stream, const dnn::BatchDescriptor& x_descriptor,
    const DeviceMemory<float>& x_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<float>& scale_data,
    const DeviceMemory<float>& offset_data,
    const DeviceMemory<float>& mean_data,
    const DeviceMemory<float>& variance_data, double epsilon,
    dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
    dnn::ProfileResult* output_profile_result) {
  return DoFusedBatchNormActivationInferenceImpl<float, float>(
      stream, miopenFloat, x_descriptor, x_data,
      scale_offset_mean_variance_descriptor, scale_data, offset_data, mean_data,
      variance_data, epsilon, activation_mode, y_data, output_profile_result);
}

bool MIOpenSupport::DoFusedBatchNormActivationInference(
    Stream* stream, const dnn::BatchDescriptor& x_descriptor,
    const DeviceMemory<Eigen::half>& x_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<float>& scale_data,
    const DeviceMemory<float>& offset_data,
    const DeviceMemory<float>& mean_data,
    const DeviceMemory<float>& variance_data, double epsilon,
    dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
    dnn::ProfileResult* output_profile_result) {
  return DoFusedBatchNormActivationInferenceImpl<Eigen::half, float>(
      stream, miopenHalf, x_descriptor, x_data,
      scale_offset_mean_variance_descriptor, scale_data, offset_data, mean_data,
      variance_data, epsilon, activation_mode, y_data, output_profile_result);
}

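// Fused batch-norm + activation for the training forward pass. In addition to
// the normalized/activated output y, this variant writes the batch mean and
// variance along with the saved mean/variance tensors that the backward pass
// consumes.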
template <typename T, typename U>
bool MIOpenSupport::DoFusedBatchNormActivationForwardImpl(
    Stream* stream,
    int miopen_type,  // Actually miopenDataType_t.
    const dnn::BatchDescriptor& x_descriptor, const DeviceMemory<T>& x_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
    double epsilon, dnn::ActivationMode activation_mode,
    DeviceMemory<T>* y_data, DeviceMemory<U>* batch_mean_data,
    DeviceMemory<U>* batch_var_data, DeviceMemory<U>* saved_mean_data,
    DeviceMemory<U>* saved_var_data,
    dnn::ProfileResult* output_profile_result) {
  auto miopen = miopen_->GetHandle(parent_, stream);

  ScopedTensorDescriptor x_nd{x_descriptor,
                              static_cast<miopenDataType_t>(miopen_type)};

  ScopedTensorDescriptor scale_offset_mean_variance_nd{
      scale_offset_mean_variance_descriptor,
      static_cast<miopenDataType_t>(miopen_type)};

  ScopedActivationDescriptor activation_desc{activation_mode};

  ScopedFusionPlanBatchNormActivationForward fusion_plan{
      miopen.handle(), x_nd.handle(), scale_offset_mean_variance_nd.handle(),
      activation_desc};

  bool retval = false;

  if (fusion_plan.CompilationSucceeded()) {
    const bool is_profiling = output_profile_result != nullptr;

    std::unique_ptr<GpuTimer> timer;
    if (is_profiling) {
      timer.reset(new GpuTimer(parent_));
      timer->Init();
      timer->Start(AsGpuStream(stream));
    }

    miopenStatus_t status = miopenStatusSuccess;

    if (status == miopenStatusSuccess) {
      fusion_plan.SetBatchNormForwardArgs(
          scale_data.opaque(), offset_data.opaque(), batch_mean_data->opaque(),
          batch_var_data->opaque(), saved_mean_data->opaque(),
          saved_var_data->opaque(), epsilon);
    }

    if (status == miopenStatusSuccess) {
      status = fusion_plan.SetActivationForwardArgs(activation_desc);
    }

    if (status == miopenStatusSuccess) {
      status = fusion_plan.Execute(x_nd.handle(), x_data.opaque(),
                                   x_nd.handle(), y_data->opaque());
    }

    if (is_profiling) {
      timer->Stop(AsGpuStream(stream));
      if (status == miopenStatusSuccess) {
        output_profile_result->set_elapsed_time_in_ms(
            timer->GetElapsedMilliseconds());
      }
      timer->Destroy();
    }

    if (status != miopenStatusSuccess) {
      // Silently return when we are profiling.
      if (!is_profiling) {
        LOG(FATAL) << "failed to enqueue fused batch-norm activation "
                      "(forward) on stream: "
                   << ToString(status);
      }
    }

    retval = true;
  }

  return retval;
}

bool MIOpenSupport::DoFusedBatchNormActivationForward(
    Stream* stream, const dnn::BatchDescriptor& x_descriptor,
    const DeviceMemory<float>& x_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<float>& scale_data,
    const DeviceMemory<float>& offset_data, double epsilon,
    dnn::ActivationMode activation_mode, DeviceMemory<float>* y_data,
    DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
    DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
    dnn::ProfileResult* output_profile_result) {
  return DoFusedBatchNormActivationForwardImpl<float, float>(
      stream, miopenFloat, x_descriptor, x_data,
      scale_offset_mean_variance_descriptor, scale_data, offset_data, epsilon,
      activation_mode, y_data, batch_mean_data, batch_var_data, saved_mean_data,
      saved_var_data, output_profile_result);
}

bool MIOpenSupport::DoFusedBatchNormActivationForward(
    Stream* stream, const dnn::BatchDescriptor& x_descriptor,
    const DeviceMemory<Eigen::half>& x_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<float>& scale_data,
    const DeviceMemory<float>& offset_data, double epsilon,
    dnn::ActivationMode activation_mode, DeviceMemory<Eigen::half>* y_data,
    DeviceMemory<float>* batch_mean_data, DeviceMemory<float>* batch_var_data,
    DeviceMemory<float>* saved_mean_data, DeviceMemory<float>* saved_var_data,
    dnn::ProfileResult* output_profile_result) {
  return DoFusedBatchNormActivationForwardImpl<Eigen::half, float>(
      stream, miopenHalf, x_descriptor, x_data,
      scale_offset_mean_variance_descriptor, scale_data, offset_data, epsilon,
      activation_mode, y_data, batch_mean_data, batch_var_data, saved_mean_data,
      saved_var_data, output_profile_result);
}

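// Fused backward pass for batch-norm + activation. The activation gradient is
// bound first (SetActivationBackwardArgs takes the forward activation output),
// then the batch-norm backward step produces the input, scale, and offset
// gradients from the saved mean/variance.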
template <typename T, typename U>
bool MIOpenSupport::DoFusedBatchNormActivationBackwardImpl(
    Stream* stream,
    int miopen_type,  // Actually miopenDataType_t.
    const dnn::BatchDescriptor& y_act_backprop_descriptor,
    const DeviceMemory<T>& y_act_backprop_data,
    const DeviceMemory<T>& y_act_data, dnn::ActivationMode activation_mode,
    const DeviceMemory<T>& x_bn_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<U>& scale_data, const DeviceMemory<U>& offset_data,
    const DeviceMemory<U>& saved_mean_data,
    const DeviceMemory<U>& saved_var_data, DeviceMemory<T>* x_bn_backprop_data,
    DeviceMemory<U>* scale_backprop_data, DeviceMemory<U>* offset_backprop_data,
    dnn::ProfileResult* output_profile_result) {
  auto miopen = miopen_->GetHandle(parent_, stream);

  ScopedTensorDescriptor y_act_backprop_nd{
      y_act_backprop_descriptor, static_cast<miopenDataType_t>(miopen_type)};

  ScopedTensorDescriptor scale_offset_mean_variance_nd{
      scale_offset_mean_variance_descriptor,
      static_cast<miopenDataType_t>(miopen_type)};

  ScopedActivationDescriptor activation_desc{activation_mode};

  ScopedFusionPlanBatchNormActivationBackward fusion_plan{
      miopen.handle(), y_act_backprop_nd.handle(),
      scale_offset_mean_variance_nd.handle(), activation_desc};

  bool retval = false;

  if (fusion_plan.CompilationSucceeded()) {
    const bool is_profiling = output_profile_result != nullptr;

    std::unique_ptr<GpuTimer> timer;
    if (is_profiling) {
      timer.reset(new GpuTimer(parent_));
      timer->Init();
      timer->Start(AsGpuStream(stream));
    }

    miopenStatus_t status = miopenStatusSuccess;

    if (status == miopenStatusSuccess) {
      fusion_plan.SetBatchNormBackwardArgs(
          x_bn_data.opaque(), scale_data.opaque(), offset_data.opaque(),
          saved_mean_data.opaque(), saved_var_data.opaque(),
          scale_backprop_data->opaque(), offset_backprop_data->opaque());
    }

    if (status == miopenStatusSuccess) {
      status = fusion_plan.SetActivationBackwardArgs(activation_desc,
                                                     y_act_data.opaque());
    }

    if (status == miopenStatusSuccess) {
      status = fusion_plan.Execute(
          y_act_backprop_nd.handle(), y_act_backprop_data.opaque(),
          y_act_backprop_nd.handle(), x_bn_backprop_data->opaque());
    }

    if (is_profiling) {
      timer->Stop(AsGpuStream(stream));
      if (status == miopenStatusSuccess) {
        output_profile_result->set_elapsed_time_in_ms(
            timer->GetElapsedMilliseconds());
      }
      timer->Destroy();
    }

    if (status != miopenStatusSuccess) {
      // Silently return when we are profiling.
      if (!is_profiling) {
        LOG(FATAL) << "failed to enqueue fused batch-norm activation "
                      "(backward) on stream: "
                   << ToString(status);
      }
    }

    retval = true;
  }

  return retval;
}

bool MIOpenSupport::DoFusedBatchNormActivationBackward(
    Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
    const DeviceMemory<float>& y_act_backprop_data,
    const DeviceMemory<float>& y_act_data, dnn::ActivationMode activation_mode,
    const DeviceMemory<float>& x_bn_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<float>& scale_data,
    const DeviceMemory<float>& offset_data,
    const DeviceMemory<float>& saved_mean_data,
    const DeviceMemory<float>& saved_var_data,
    DeviceMemory<float>* x_bn_backprop_data,
    DeviceMemory<float>* scale_backprop_data,
    DeviceMemory<float>* offset_backprop_data,
    dnn::ProfileResult* output_profile_result) {
  return DoFusedBatchNormActivationBackwardImpl<float, float>(
      stream, miopenFloat, y_act_backprop_descriptor, y_act_backprop_data,
      y_act_data, activation_mode, x_bn_data,
      scale_offset_mean_variance_descriptor, scale_data, offset_data,
      saved_mean_data, saved_var_data, x_bn_backprop_data, scale_backprop_data,
      offset_backprop_data, output_profile_result);
}

bool MIOpenSupport::DoFusedBatchNormActivationBackward(
    Stream* stream, const dnn::BatchDescriptor& y_act_backprop_descriptor,
    const DeviceMemory<Eigen::half>& y_act_backprop_data,
    const DeviceMemory<Eigen::half>& y_act_data,
    dnn::ActivationMode activation_mode,
    const DeviceMemory<Eigen::half>& x_bn_data,
    const dnn::BatchDescriptor& scale_offset_mean_variance_descriptor,
    const DeviceMemory<float>& scale_data,
    const DeviceMemory<float>& offset_data,
    const DeviceMemory<float>& saved_mean_data,
    const DeviceMemory<float>& saved_var_data,
    DeviceMemory<Eigen::half>* x_bn_backprop_data,
    DeviceMemory<float>* scale_backprop_data,
    DeviceMemory<float>* offset_backprop_data,
    dnn::ProfileResult* output_profile_result) {
  return DoFusedBatchNormActivationBackwardImpl<Eigen::half, float>(
      stream, miopenHalf, y_act_backprop_descriptor, y_act_backprop_data,
      y_act_data, activation_mode, x_bn_data,
      scale_offset_mean_variance_descriptor, scale_data, offset_data,
      saved_mean_data, saved_var_data, x_bn_backprop_data, scale_backprop_data,
      offset_backprop_data, output_profile_result);
}

}  // namespace gpu

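// Registers MIOpenSupport as the DNN plugin for the ROCm platform (unless a
// factory is already registered) and makes it the default DNN factory. The
// factory refuses to construct the plugin for a non-ROCm StreamExecutor.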
void initialize_miopen() {
  auto miopenAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
      rocm::kROCmPlatformId, PluginKind::kDnn, gpu::kMIOpenPlugin);

  if (!miopenAlreadyRegistered) {
    port::Status status =
        PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
            rocm::kROCmPlatformId, gpu::kMIOpenPlugin, "MIOpen",
            [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
              gpu::GpuExecutor* rocm_executor =
                  dynamic_cast<gpu::GpuExecutor*>(parent);
              if (rocm_executor == nullptr) {
                LOG(ERROR)
                    << "Attempting to initialize an instance of the MIOpen "
                    << "support library with a non-ROCM StreamExecutor";
                return nullptr;
              }

              gpu::MIOpenSupport* dnn = new gpu::MIOpenSupport(rocm_executor);
              if (!dnn->Init().ok()) {
                // Note: Init() will log a more specific error.
                delete dnn;
                return nullptr;
              }
              return dnn;
            });

    if (!status.ok()) {
      LOG(ERROR) << "Unable to register MIOpen factory: "
                 << status.error_message();
    }

    PluginRegistry::Instance()->SetDefaultFactory(
        rocm::kROCmPlatformId, PluginKind::kDnn, gpu::kMIOpenPlugin);
  }
}

}  // namespace stream_executor

REGISTER_MODULE_INITIALIZER(register_miopen,
                            { stream_executor::initialize_miopen(); });