/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <limits>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cuh"

namespace mindspore {
namespace kernel {
constexpr size_t kPrevOutput0th = 0;
constexpr size_t kPrevOutput1st = 1;
constexpr size_t kPrevOutput2nd = 2;
constexpr size_t kPrevOutput3rd = 3;
constexpr size_t kProbDimSize = 3;
constexpr size_t kIndicesDimSize = 2;
constexpr size_t kInputIdxForProbs = 0;
constexpr size_t kInputIdxForLabelIndices = 1;
constexpr size_t kInputIdxForLabelValues = 2;
constexpr size_t kInputIdxForSeqLen = 3;
constexpr size_t kWsIdxForSoftmaxProbs = 0;
constexpr size_t kWsIdxForCumLabelLen = 1;
constexpr size_t kWsIdxForLabelSquenceLen = 2;
constexpr size_t kWsIdxForLabelValueSp = 3;
constexpr size_t kWsIdxForLabelValuePcr = 4;
constexpr size_t kWsIdxForProbNum = 5;
constexpr size_t kWsIdxForPrecumLabelLen = 6;
constexpr size_t kWsIdxForMaxLabelLen = 7;
constexpr size_t kProbDimsIdxForMaxTime = 0;
constexpr size_t kProbDimsIdxForBatch = 1;
constexpr size_t kProbDimsIdxForNumClass = 2;
constexpr size_t kCTCLossInputsNum = 4;
constexpr size_t kCTCLossOutputsNum = 2;

template <typename T>
class CtcLossGpuKernelMod : public NativeGpuKernelMod {
 public:
  CtcLossGpuKernelMod()
      : label_indice_size_(0),
        label_size_(0),
        sequence_lengths_size_(0),
        preprocess_collapse_repeated_(false),
        ctc_merge_repeated_(true),
        ignore_longer_outputs_than_inputs_(false),
        is_null_input_(false),
        kernel_name_("CTCLoss"),
        probs(nullptr),
        label_indices(nullptr),
        label_values(nullptr),
        sequence_length(nullptr),
        costs(nullptr),
        grads(nullptr),
        softmax_probs(nullptr),
        cum_labels_length(nullptr),
        label_squence_length(nullptr),
        label_value_sp(nullptr),
        label_value_pcr(nullptr),
        prob_num(nullptr),
        precum_labels_length(nullptr),
        max_labels_length(nullptr),
        numclass(0),
        batch(0),
        max_time(0),
        max_sequence(0),
        max_labels_length_host(0),
        batch_label(0),
        label_value_with_blank(nullptr),
        log_alpha_b(nullptr),
        log_beta_b(nullptr) {}
  ~CtcLossGpuKernelMod() override = default;
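
  // Launch proceeds in three steps: LaunchInit resolves device addresses and shape-derived scalars,
  // LaunchFirstHalf computes the softmax of the input probabilities and preprocesses the sparse labels,
  // and LaunchSecondHalf runs the CTC forward/backward recursions to produce the loss and gradients.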
  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    LaunchInit(inputs, workspace, outputs);
    LaunchFirstHalf(inputs, workspace, outputs, stream_ptr);
    LaunchSecondHalf(inputs, workspace, outputs, stream_ptr);
    return true;
  }

  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
    CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCTCLossInputsNum, kernel_name_);
    CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCTCLossOutputsNum, kernel_name_);

    PrimitivePtr prim = primitive_;
    MS_EXCEPTION_IF_NULL(prim);

    preprocess_collapse_repeated_ = GetValue<bool>(prim->GetAttr("preprocess_collapse_repeated"));
    ctc_merge_repeated_ = GetValue<bool>(prim->GetAttr("ctc_merge_repeated"));
    ignore_longer_outputs_than_inputs_ = GetValue<bool>(prim->GetAttr("ignore_longer_outputs_than_inputs"));
    InitResource();
    return true;
  }

  int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
    if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) {
      return ret;
    }
    auto shape_signed = inputs[kPrevOutput0th]->GetShapeVector();
    auto probs_shape = Convert2SizeTClipNeg(shape_signed);
    auto indice_dims = inputs[kPrevOutput1st]->GetShapeVector();
    auto labels_dims = inputs[kPrevOutput2nd]->GetShapeVector();
    auto sequence_length_dims = inputs[kPrevOutput3rd]->GetShapeVector();
    is_null_input_ = CHECK_SHAPE_NULL(probs_shape, kernel_name_, "x") ||
                     CHECK_SHAPE_NULL(indice_dims, kernel_name_, "labels_indices") ||
                     CHECK_SHAPE_NULL(labels_dims, kernel_name_, "labels_values") ||
                     CHECK_SHAPE_NULL(sequence_length_dims, kernel_name_, "sequence_length");
    if (is_null_input_) {
      InitSizeLists();
      return KRET_OK;
    }
    if (probs_shape.size() != kProbDimSize) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of x must be 3, but got "
                        << probs_shape.size();
    }
    probs_dims_[kProbDimsIdxForMaxTime] = probs_shape[kProbDimsIdxForMaxTime];
    probs_dims_[kProbDimsIdxForBatch] = probs_shape[kProbDimsIdxForBatch];
    probs_dims_[kProbDimsIdxForNumClass] = probs_shape[kProbDimsIdxForNumClass];

    if (labels_dims.size() != 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of labels_values must be 1, but got "
                        << labels_dims.size();
    }
    if (indice_dims.size() != kIndicesDimSize) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of labels_indices must be 2, but got "
                        << indice_dims.size();
    }
    label_size_ = sizeof(int);
    label_size_ *= SizeOf(labels_dims);
    label_indice_size_ = sizeof(int64_t);
    label_indice_size_ *= SizeOf(indice_dims);

    sequence_lengths_size_ = LongToSizeClipNeg(sequence_length_dims[0]) * sizeof(int);
    InitSizeLists();
    return KRET_OK;
  }

 protected:
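  // Fetches the device addresses of the inputs, outputs and workspace buffers, caches the probability
  // dimensions (max_time, batch, numclass), and resets the per-launch scalars and temporary pointers.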
  void LaunchInit(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
                  const std::vector<KernelTensor *> &outputs) {
    probs = GetDeviceAddress<T>(inputs, kInputIdxForProbs);
    label_indices = GetDeviceAddress<int64_t>(inputs, kInputIdxForLabelIndices);
    label_values = GetDeviceAddress<int>(inputs, kInputIdxForLabelValues);
    sequence_length = GetDeviceAddress<int>(inputs, kInputIdxForSeqLen);
    costs = GetDeviceAddress<T>(outputs, 0);
    grads = GetDeviceAddress<T>(outputs, 1);
    softmax_probs = GetDeviceAddress<T>(workspace, kWsIdxForSoftmaxProbs);
    cum_labels_length = GetDeviceAddress<int>(workspace, kWsIdxForCumLabelLen);
    label_squence_length = GetDeviceAddress<int>(workspace, kWsIdxForLabelSquenceLen);
    label_value_sp = GetDeviceAddress<int>(workspace, kWsIdxForLabelValueSp);
    label_value_pcr = GetDeviceAddress<int>(workspace, kWsIdxForLabelValuePcr);
    prob_num = GetDeviceAddress<T>(workspace, kWsIdxForProbNum);
    precum_labels_length = GetDeviceAddress<int>(workspace, kWsIdxForPrecumLabelLen);
    max_labels_length = GetDeviceAddress<int>(workspace, kWsIdxForMaxLabelLen);
    numclass = SizeToInt(probs_dims_[kProbDimsIdxForNumClass]);
    batch = SizeToInt(probs_dims_[kProbDimsIdxForBatch]);
    max_time = SizeToInt(probs_dims_[kProbDimsIdxForMaxTime]);
    max_sequence = 0;
    max_labels_length_host = 0;
    batch_label = 0;
    label_value_with_blank = nullptr;
    log_alpha_b = nullptr;
    log_beta_b = nullptr;
  }

  void LaunchFirstHalf(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
                       const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    cudaError_t status = cudaErrorNotReady;
    status = CalculateMaxSequence(sequence_length, max_labels_length, batch, stream);
    CHECK_CUDA_STATUS(status, "CalculateMaxSequence called by " + kernel_name_);
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(&max_sequence, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    if (max_time < max_sequence) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the x[0] must be equal to or greater than max_sequence, "
                        << "but got x[0]: " << max_time << ", max_sequence: " << max_sequence;
    }
    status = InnerSoftMax(probs, softmax_probs, sequence_length, max_time, batch, numclass, stream);
    CHECK_CUDA_STATUS(status, "InnerSoftMax called by " + kernel_name_);
    MemsetForWS(label_value_pcr, cum_labels_length, label_squence_length, costs, grads, stream);
    status = CalculatePreLength(label_squence_length, precum_labels_length, cum_labels_length, max_labels_length,
                                label_indices, batch, label_size_ / sizeof(int), stream);
    CHECK_CUDA_STATUS(status, "CalculatePreLength called by " + kernel_name_);
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(&batch_label, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    if (batch != batch_label + 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the batch size of input must be equal to "
                        << (batch_label + 1) << ", but got " << batch;
    }
    status = GenLabelValue(label_value_sp, label_indices, label_values, label_squence_length, cum_labels_length,
                           max_labels_length, label_size_ / sizeof(int), numclass - 1, batch, stream);
    CHECK_CUDA_STATUS(status, "GenLabelValue called by " + kernel_name_);
    if (preprocess_collapse_repeated_) {
      status = GenLabelValuePCR(label_value_sp, label_value_pcr, label_squence_length, cum_labels_length,
                                max_labels_length, batch, stream);
      CHECK_CUDA_STATUS(status, "GenLabelValuePCR called by " + kernel_name_);
    }
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(&max_labels_length_host, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
  }
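
  // Allocates the per-launch buffers (alpha/beta log-probability tables and the blank-augmented label
  // sequence), runs the CTC forward and backward recursions, writes the loss and gradient, and frees
  // the temporary buffers.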
  void LaunchSecondHalf(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
                        const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    const int SOffSet = 2 * max_labels_length_host + 1;
    int log_prob_size = batch * SOffSet * max_time;
    cudaError_t status = cudaErrorNotReady;

    if (!ignore_longer_outputs_than_inputs_ && max_labels_length_host > max_time) {
      MS_LOG(EXCEPTION) << "output size is greater than input size.";
    }
    MemManageForCus(&log_alpha_b, &log_beta_b, &label_value_with_blank, cum_labels_length, log_prob_size, batch,
                    stream);

    if (preprocess_collapse_repeated_) {
      status = GenLabelWithBlank(label_value_pcr, label_value_with_blank, label_squence_length, precum_labels_length,
                                 cum_labels_length, batch, numclass - 1, stream);
    } else {
      status = GenLabelWithBlank(label_value_sp, label_value_with_blank, label_squence_length, precum_labels_length,
                                 cum_labels_length, batch, numclass - 1, stream);
    }
    CHECK_CUDA_STATUS(status, "GenLabelWithBlank called by " + kernel_name_);

    status = CalculateFwdVar(log_alpha_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated_,
                             batch, SOffSet, max_time, numclass - 1, label_squence_length, cum_labels_length,
                             ignore_longer_outputs_than_inputs_, stream);
    CHECK_CUDA_STATUS(status, "CalculateFwdVar called by " + kernel_name_);
    status = CalculateBwdVar(log_beta_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated_,
                             batch, SOffSet, max_time, numclass - 1, label_squence_length, cum_labels_length,
                             ignore_longer_outputs_than_inputs_, stream);
    CHECK_CUDA_STATUS(status, "CalculateBwdVar called by " + kernel_name_);
    status = CTCLoss(log_alpha_b, log_beta_b, softmax_probs, label_value_with_blank, batch, SOffSet, max_time,
                     numclass, sequence_length, label_squence_length, cum_labels_length, costs, grads, prob_num,
                     ignore_longer_outputs_than_inputs_, stream);
    CHECK_CUDA_STATUS(status, kernel_name_);
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    FreeMem(label_value_with_blank, log_alpha_b, log_beta_b);
  }
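
  // Workspace buffers, in kWsIdxFor* order: softmax probabilities (same size as the input), cumulative
  // label lengths, per-sequence label lengths, sparse label values, collapsed label values, a per-class
  // probability buffer (same size as the input), prefix label lengths, and one int holding the maximum
  // label length. Outputs are the per-batch costs and the gradient with the input's shape.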
  void InitSizeLists() {
    workspace_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.push_back(probs_dims_[kProbDimsIdxForMaxTime] * probs_dims_[kProbDimsIdxForBatch] *
                                   probs_dims_[kProbDimsIdxForNumClass] * sizeof(T));
    workspace_size_list_.push_back(sequence_lengths_size_);
    workspace_size_list_.push_back(sequence_lengths_size_);
    workspace_size_list_.push_back(label_size_);
    workspace_size_list_.push_back(label_size_);
    workspace_size_list_.push_back(probs_dims_[kProbDimsIdxForMaxTime] * probs_dims_[kProbDimsIdxForBatch] *
                                   probs_dims_[kProbDimsIdxForNumClass] * sizeof(T));
    workspace_size_list_.push_back(sequence_lengths_size_);
    workspace_size_list_.push_back(sizeof(int));
    output_size_list_.push_back(probs_dims_[kProbDimsIdxForBatch] * sizeof(T));
    output_size_list_.push_back(probs_dims_[kProbDimsIdxForMaxTime] * probs_dims_[kProbDimsIdxForBatch] *
                                probs_dims_[kProbDimsIdxForNumClass] * sizeof(T));
  }

  void MemsetForWS(int *label_value_pcr, int *cum_labels_length, int *label_squence_length, T *costs, T *grads,
                   cudaStream_t stream) {
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(label_value_pcr, static_cast<int>(0), label_size_, stream),
                                       "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemsetAsync(cum_labels_length, static_cast<int>(0), sequence_lengths_size_, stream),
      "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemsetAsync(label_squence_length, static_cast<int>(0), sequence_lengths_size_, stream),
      "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemsetAsync(costs, static_cast<T>(0), probs_dims_[kProbDimsIdxForBatch] * sizeof(T), stream),
      "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemsetAsync(grads, static_cast<T>(0),
                      probs_dims_[kProbDimsIdxForMaxTime] * probs_dims_[kProbDimsIdxForBatch] *
                        probs_dims_[kProbDimsIdxForNumClass] * sizeof(T),
                      stream),
      "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
  }

  void MemManageForCus(T **log_alpha_b, T **log_beta_b, int **label_value_with_blank, int *cum_labels_length,
                       int log_prob_size, int batch, cudaStream_t stream) {
    int total_labels_size_host = 0;
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMalloc(reinterpret_cast<void **>(log_alpha_b), sizeof(T) * log_prob_size),
                                       "cudaMalloc failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMalloc(reinterpret_cast<void **>(log_beta_b), sizeof(T) * log_prob_size),
                                       "cudaMalloc failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&total_labels_size_host, cum_labels_length + batch - 1,
                                                       sizeof(int), cudaMemcpyDeviceToHost, stream),
                                       "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMalloc(reinterpret_cast<void **>(label_value_with_blank),
                 sizeof(int) * (2 * total_labels_size_host + batch)),
      "cudaMalloc failed.");
  }

  void FreeMem(int *label_value_with_blank, T *log_alpha_b, T *log_beta_b) {
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(label_value_with_blank), "cudaFree failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(log_alpha_b), "cudaFree failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(log_beta_b), "cudaFree failed.");
  }
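
  // Sizes below are computed in Resize(); the raw pointers are refreshed from the kernel tensors on
  // every Launch(). log_alpha_b, log_beta_b and label_value_with_blank are allocated per launch in
  // MemManageForCus and released in FreeMem.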
  size_t probs_dims_[3] = {0};
  int label_indice_size_;
  int label_size_;
  int sequence_lengths_size_;
  bool preprocess_collapse_repeated_;
  bool ctc_merge_repeated_;
  bool ignore_longer_outputs_than_inputs_;
  bool is_null_input_;
  std::string kernel_name_;
  T kLogZero_ = -std::numeric_limits<T>::infinity();

  // Heap parameter
  T *probs;
  int64_t *label_indices;
  int *label_values;
  int *sequence_length;
  T *costs;
  T *grads;
  T *softmax_probs;
  int *cum_labels_length;
  int *label_squence_length;
  int *label_value_sp;
  int *label_value_pcr;
  T *prob_num;
  int *precum_labels_length;
  int *max_labels_length;
  int numclass;
  int batch;
  int max_time;
  int max_sequence;
  int max_labels_length_host;
  int batch_label;
  int *label_value_with_blank;
  T *log_alpha_b;
  T *log_beta_b;
};  // class CtcLossGpuKernelMod
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_