/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_UPDATE_THOR_GRADIENT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_UPDATE_THOR_GRADIENT_GPU_KERNEL_H_
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/convert_gradient_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pad_impl.cuh"
#include "include/common/utils/convert_utils.h"

namespace mindspore {
namespace kernel {
// Blocking/padding metadata for splitting the gradient into split_dim x split_dim tiles.
struct GradientSize {
  size_t batch_h;     // number of tile rows
  size_t batch_w;     // number of tile columns
  size_t h;           // tile height
  size_t w;           // tile width
  size_t ori_h;       // original gradient height
  size_t ori_w;       // original gradient width
  size_t pad_h;       // rows of zero padding appended at the bottom
  size_t pad_w;       // columns of zero padding appended at the right
  bool need_convert;  // whether the row-strip GEMM result must be rearranged into per-tile layout
  cudaDataType_t dtype;
};

template <typename T>
class UpdateThorGradientGpuKernelMod : public NativeGpuKernelMod {
 public:
  UpdateThorGradientGpuKernelMod() : split_dim(128), is_null_input_(false), handle_(nullptr) {}
  ~UpdateThorGradientGpuKernelMod() = default;

  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    CHECK_CUBLAS_RET_WITH_ERROR(cublasSetStream(handle_, reinterpret_cast<cudaStream_t>(stream_ptr)),
                                "cublasSetStream failed");
    auto input1_addr = GetDeviceAddress<T>(inputs, 0);
    auto input2_addr = GetDeviceAddress<T>(inputs, 1);
    auto input3_addr = GetDeviceAddress<T>(inputs, 2);
    auto workspace1_addr = GetDeviceAddress<T>(workspace, 0);
    T *workspace2_addr = nullptr;
    T *workspace3_addr = nullptr;
    if (gradient_size.need_convert) {
      workspace2_addr = GetDeviceAddress<T>(workspace, 1);
      workspace3_addr = GetDeviceAddress<T>(workspace, 2);
    }
    T *workspace4_addr = nullptr;
    cudaError_t status = cudaErrorNotReady;
    auto output_addr = GetDeviceAddress<T>(outputs, 0);
    if (gradient_size.pad_h != 0 || gradient_size.pad_w != 0) {
      // Zero-pad the gradient so both dimensions become multiples of split_dim.
      workspace4_addr = GetDeviceAddress<T>(workspace, 3);
      const size_t size = (gradient_size.ori_h + gradient_size.pad_h) * (gradient_size.ori_w + gradient_size.pad_w);
      status = CalPad(size, input2_addr, 1, 1, gradient_size.ori_h, gradient_size.ori_w,
                      gradient_size.ori_h + gradient_size.pad_h, gradient_size.ori_w + gradient_size.pad_w, 0, 0, 0.0,
                      workspace4_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
      CHECK_CUDA_STATUS(status, kernel_name_);
      cudaMemsetAsync(workspace1_addr, 0,
                      gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * sizeof(T),
                      reinterpret_cast<cudaStream_t>(stream_ptr));
      input2_addr = workspace4_addr;
    }
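    // First batched GEMM: for each of the batch_h row strips, left-multiply the (padded) gradient
    // strip by the corresponding h x h block of matrix_a. The buffers are row-major while cuBLAS is
    // column-major, so the operands are passed in swapped order; the call below then yields
    // matrix_a_block * gradient_strip directly in row-major layout.
    // (Explanatory note inferred from the leading dimensions and strides used below.)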
    const float alpha = 1;
    const float beta = 0;
    const int lda = SizeToInt(gradient_size.h);
    const int ldb = SizeToInt(gradient_size.ori_w + gradient_size.pad_w);
    const int ldc = SizeToInt(gradient_size.ori_w + gradient_size.pad_w);

    auto stride_a = SizeToInt(gradient_size.h * gradient_size.h);
    auto stride_b = SizeToInt(gradient_size.h * (gradient_size.ori_w + gradient_size.pad_w));
    auto stride_c = SizeToInt(gradient_size.h * (gradient_size.ori_w + gradient_size.pad_w));

    try {
      CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
        cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(gradient_size.ori_w),
                                   SizeToInt(gradient_size.h), SizeToInt(gradient_size.h), &alpha, input2_addr,
                                   gradient_size.dtype, ldb, stride_b, input1_addr, gradient_size.dtype, lda, stride_a,
                                   &beta, workspace1_addr, gradient_size.dtype, ldc, stride_c, gradient_size.batch_h,
                                   CUDA_R_32F, algo_),
        "cublasGemmStridedBatchedEx call failed");
    } catch (const std::exception &e) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', encountered an exception: " << e.what()
                        << " when invoking cublasGemmStridedBatchedEx";
    }

    auto r_input_addr = workspace1_addr;
    if (gradient_size.need_convert) {
      // Rearrange the row-strip result into contiguous h x w tiles so the second batched GEMM
      // can treat each tile as an independent batch element.
      size_t size = gradient_size.batch_w * gradient_size.batch_h * gradient_size.w * gradient_size.h;
      status = ConvertGradient(size, gradient_size.h, gradient_size.w, gradient_size.batch_w,
                               gradient_size.batch_w * gradient_size.w, workspace1_addr, workspace2_addr,
                               reinterpret_cast<cudaStream_t>(stream_ptr));
      CHECK_CUDA_STATUS(status, kernel_name_);
      r_input_addr = workspace2_addr;
    }

    const int lda_r = SizeToInt(gradient_size.w);
    const int ldb_r = SizeToInt(gradient_size.w);
    const int ldc_r = SizeToInt(gradient_size.w);

    stride_a = SizeToInt(gradient_size.h * gradient_size.w);
    stride_b = SizeToInt(gradient_size.w * gradient_size.w);
    stride_c = SizeToInt(gradient_size.h * gradient_size.w);
    auto r_output_addr = output_addr;
    if (gradient_size.need_convert) {
      r_output_addr = workspace3_addr;
    }
    // Second batched GEMM: right-multiply every tile by the matching w x w block of matrix_g.
    CHECK_CUBLAS_RET_WITH_EXCEPT_NOTRACE(
      cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(gradient_size.w),
                                 SizeToInt(gradient_size.h), SizeToInt(gradient_size.w), &alpha, input3_addr,
                                 gradient_size.dtype, ldb_r, stride_b, r_input_addr, gradient_size.dtype, lda_r,
                                 stride_a, &beta, r_output_addr, gradient_size.dtype, ldc_r, stride_c,
                                 gradient_size.batch_h * gradient_size.batch_w, CUDA_R_32F, algo_),
      "cublasGemmStridedBatchedEx call failed");
    if (gradient_size.need_convert) {
      // Convert the tiled result back to the original (unpadded) row-major gradient layout.
      size_t size = gradient_size.batch_w * gradient_size.batch_h * gradient_size.w * gradient_size.h;
      if (gradient_size.pad_h == 0 && gradient_size.pad_w == 0) {
        status = ConvertGradientBack(size, gradient_size.h, gradient_size.w, gradient_size.batch_w,
                                     gradient_size.batch_w * gradient_size.w, r_output_addr, output_addr,
                                     reinterpret_cast<cudaStream_t>(stream_ptr));
      } else {
        status = ConvertGradientBack(size, gradient_size.h, gradient_size.w, gradient_size.ori_h, gradient_size.ori_w,
                                     gradient_size.batch_w, gradient_size.ori_w, r_output_addr, output_addr,
                                     reinterpret_cast<cudaStream_t>(stream_ptr));
      }
      CHECK_CUDA_STATUS(status, kernel_name_);
    }
    return true;
  }

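  // Reference sketch (illustrative only, not part of the original code and not used by the kernel):
  // a plain per-tile view of the update that the two batched GEMM calls in Launch() carry out on
  // the GPU. tile, a_block, g_block and out_tile are hypothetical names for the corresponding
  // split_dim-sized blocks of the gradient, matrix_a, matrix_g and the output.
  //
  //   for (size_t i = 0; i < batch_h; ++i) {
  //     for (size_t j = 0; j < batch_w; ++j) {
  //       out_tile(i, j) = a_block(i) * tile(i, j) * g_block(i, j);  // matrix products
  //     }
  //   }
  //
  // Tiles on the right/bottom edges are zero padded when the gradient shape is not a multiple of
  // split_dim, and the padding is dropped again when the result is written to the output.
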
  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
    return true;
  }

  int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
    handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
    output_size_list_.clear();
    (void)SetProperty(primitive_, inputs, outputs);
    InitSizeLists();
    return KRET_OK;
  }

 protected:
  void InitSizeLists() {
    size_t unit_size = sizeof(T);

    size_t output_size = gradient_size.ori_h * gradient_size.ori_w * unit_size;
    output_size_list_.push_back(output_size);

    // Workspace 0: result of the first batched GEMM (padded, row-strip layout).
    size_t workspace_size_ =
      gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size;
    workspace_size_list_.push_back(workspace_size_);

    if (gradient_size.need_convert) {
      // Workspaces 1 and 2: tile-layout buffers used around the second batched GEMM.
      workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size;
      workspace_size_list_.push_back(workspace_size_);
      workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size;
      workspace_size_list_.push_back(workspace_size_);
    }

    if (gradient_size.pad_h != 0 || gradient_size.pad_w != 0) {
      // Workspace 3: zero-padded copy of the input gradient.
      workspace_size_ =
        (gradient_size.ori_w + gradient_size.pad_w) * (gradient_size.ori_h + gradient_size.pad_h) * unit_size;
      workspace_size_list_.push_back(workspace_size_);
    }
  }

 private:
  void SetProperty(const PrimitivePtr &primitive, const std::vector<KernelTensor *> &inputs,
                   const std::vector<KernelTensor *> &outputs) {
    const auto &matrix_a_shape = inputs[kIndex0]->GetShapeVector();
    const auto &shape_signed = inputs[kIndex1]->GetShapeVector();
    auto gradient_shape = Convert2SizeTClipNeg(shape_signed);
    const auto &matrix_g_shape = inputs[kIndex2]->GetShapeVector();
    if (AnfAlgo::IsShapesDynamic({matrix_a_shape, shape_signed, matrix_g_shape})) {
      return;
    }
    is_null_input_ = CHECK_SHAPE_NULL(matrix_a_shape, kernel_name_, "matrix_a") ||
                     CHECK_SHAPE_NULL(gradient_shape, kernel_name_, "gradient") ||
                     CHECK_SHAPE_NULL(matrix_g_shape, kernel_name_, "matrix_g");
    if (is_null_input_) {
      return;
    }

    split_dim = LongToSize(GetValue<int64_t>(primitive->GetAttr("split_dim")));
    if (split_dim == 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', split_dim cannot be 0 (it is used as a divisor).";
    }
    gradient_size.batch_h = gradient_shape[0] / split_dim;
    gradient_size.batch_w = gradient_shape[1] / split_dim;
    // Height dimension: round the tile count up and record the padding needed to reach a
    // multiple of split_dim; shapes smaller than split_dim form a single unpadded tile.
    if (gradient_size.batch_h * split_dim != gradient_shape[0]) {
      gradient_size.batch_h += 1;
      if (gradient_shape[0] > split_dim) {
        gradient_size.h = split_dim;
        gradient_size.pad_h = gradient_size.batch_h * split_dim - gradient_shape[0];
      } else {
        gradient_size.h = gradient_shape[0];
        gradient_size.pad_h = 0;
      }
    } else {
      gradient_size.h = split_dim;
      gradient_size.pad_h = 0;
    }

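    // Illustrative example (not from the original source): with split_dim = 128 and a gradient
    // of shape [300, 64], batch_h = 3, h = 128, pad_h = 3 * 128 - 300 = 84, while the width
    // branch below gives batch_w = 1, w = 64, pad_w = 0.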
    // Width dimension: same rule as the height dimension above.
    if (gradient_size.batch_w * split_dim != gradient_shape[1]) {
      gradient_size.batch_w += 1;
      if (gradient_shape[1] > split_dim) {
        gradient_size.w = split_dim;
        gradient_size.pad_w = gradient_size.batch_w * split_dim - gradient_shape[1];
      } else {
        gradient_size.w = gradient_shape[1];
        gradient_size.pad_w = 0;
      }
    } else {
      gradient_size.w = split_dim;
      gradient_size.pad_w = 0;
    }

    // Layout conversion is only needed when the gradient spans more than one tile column.
    if (gradient_size.batch_w * gradient_size.w <= split_dim) {
      gradient_size.need_convert = false;
    } else {
      gradient_size.need_convert = true;
    }

    gradient_size.ori_w = gradient_shape[1];
    gradient_size.ori_h = gradient_shape[0];
    gradient_size.dtype = GetCudaDataType(TypeIdLabel(inputs[kIndex1]->dtype_id()));
  }

  size_t split_dim;
  bool is_null_input_;
  struct GradientSize gradient_size;
  cublasHandle_t handle_;
  cublasGemmAlgo_t algo_ = CUBLAS_GEMM_DEFAULT;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_UPDATE_THOR_GRADIENT_GPU_KERNEL_H_