/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <limits>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/hal/device/gpu_memory_allocator.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/ctcloss_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t kPrevOutput0th = 0;
constexpr size_t kPrevOutput1st = 1;
constexpr size_t kPrevOutput2nd = 2;
constexpr size_t kPrevOutput3rd = 3;
constexpr size_t kProbDimSize = 3;
constexpr size_t kIndicesDimSize = 2;
constexpr size_t kInputIdxForProbs = 0;
constexpr size_t kInputIdxForLabelIndices = 1;
constexpr size_t kInputIdxForLabelValues = 2;
constexpr size_t kInputIdxForSeqLen = 3;
constexpr size_t kWsIdxForSoftmaxProbs = 0;
constexpr size_t kWsIdxForCumLabelLen = 1;
constexpr size_t kWsIdxForLabelSquenceLen = 2;
constexpr size_t kWsIdxForLabelValueSp = 3;
constexpr size_t kWsIdxForLabelValuePcr = 4;
constexpr size_t kWsIdxForProbNum = 5;
constexpr size_t kWsIdxForPrecumLabelLen = 6;
constexpr size_t kWsIdxForMaxLabelLen = 7;
constexpr size_t kProbDimsIdxForMaxTime = 0;
constexpr size_t kProbDimsIdxForBatch = 1;
constexpr size_t kProbDimsIdxForNumClass = 2;
constexpr size_t kCTCLossInputsNum = 4;
constexpr size_t kCTCLossOutputsNum = 2;

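// CtcLossGpuKernelMod implements the CTCLoss (Connectionist Temporal Classification loss) GPU
// kernel. The inputs follow the kInputIdxFor* order above: x (probs) with shape
// [max_time, batch, num_classes], labels_indices with shape [num_labels, 2], labels_values with
// shape [num_labels], and sequence_length with shape [batch]. The two outputs are the per-batch
// loss and a gradient tensor with the same shape as x.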
template <typename T>
class CtcLossGpuKernelMod : public NativeGpuKernelMod {
 public:
  CtcLossGpuKernelMod()
      : label_indice_size_(0),
        label_size_(0),
        sequence_lengths_size_(0),
        preprocess_collapse_repeated_(false),
        ctc_merge_repeated_(true),
        ignore_longer_outputs_than_inputs_(false),
        is_null_input_(false),
        kernel_name_("CTCLoss"),
        probs(nullptr),
        label_indices(nullptr),
        label_values(nullptr),
        sequence_length(nullptr),
        costs(nullptr),
        grads(nullptr),
        softmax_probs(nullptr),
        cum_labels_length(nullptr),
        label_squence_length(nullptr),
        label_value_sp(nullptr),
        label_value_pcr(nullptr),
        prob_num(nullptr),
        precum_labels_length(nullptr),
        max_labels_length(nullptr),
        numclass(0),
        batch(0),
        max_time(0),
        max_sequence(0),
        max_labels_length_host(0),
        batch_label(0),
        label_value_with_blank(nullptr),
        log_alpha_b(nullptr),
        log_beta_b(nullptr) {}
  ~CtcLossGpuKernelMod() override = default;

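  // Launch runs the whole CTCLoss computation on the given CUDA stream. It is a no-op when any
  // input shape is empty (is_null_input_); otherwise it binds the device addresses and runs the
  // two launch phases below.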
  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
              const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    LaunchInit(inputs, workspace, outputs);
    LaunchFirstHalf(inputs, workspace, outputs, stream_ptr);
    LaunchSecondHalf(inputs, workspace, outputs, stream_ptr);
    return true;
  }

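  // Init checks the number of inputs/outputs and reads the CTCLoss primitive attributes
  // preprocess_collapse_repeated, ctc_merge_repeated, and ignore_longer_outputs_than_inputs.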
  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
    CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCTCLossInputsNum, kernel_name_);
    CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCTCLossOutputsNum, kernel_name_);

    PrimitivePtr prim = primitive_;
    MS_EXCEPTION_IF_NULL(prim);

    preprocess_collapse_repeated_ = GetValue<bool>(prim->GetAttr("preprocess_collapse_repeated"));
    ctc_merge_repeated_ = GetValue<bool>(prim->GetAttr("ctc_merge_repeated"));
    ignore_longer_outputs_than_inputs_ = GetValue<bool>(prim->GetAttr("ignore_longer_outputs_than_inputs"));
    InitResource();
    return true;
  }

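  // Resize re-reads the input shapes, validates their ranks (x must be 3-D, labels_values 1-D,
  // labels_indices 2-D), and recomputes the byte sizes used for the workspace and output buffers.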
  int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
    if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) {
      return ret;
    }
    auto shape_signed = inputs[kPrevOutput0th]->GetShapeVector();
    auto probs_shape = Convert2SizeTClipNeg(shape_signed);
    auto indice_dims = inputs[kPrevOutput1st]->GetShapeVector();
    auto labels_dims = inputs[kPrevOutput2nd]->GetShapeVector();
    auto sequence_length_dims = inputs[kPrevOutput3rd]->GetShapeVector();
    is_null_input_ = CHECK_SHAPE_NULL(probs_shape, kernel_name_, "x") ||
                     CHECK_SHAPE_NULL(indice_dims, kernel_name_, "labels_indices") ||
                     CHECK_SHAPE_NULL(labels_dims, kernel_name_, "labels_values") ||
                     CHECK_SHAPE_NULL(sequence_length_dims, kernel_name_, "sequence_length");
    if (is_null_input_) {
      InitSizeLists();
      return KRET_OK;
    }
    if (probs_shape.size() != kProbDimSize) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of x must be 3, but got " << probs_shape.size();
    }
    probs_dims_[kProbDimsIdxForMaxTime] = probs_shape[kProbDimsIdxForMaxTime];
    probs_dims_[kProbDimsIdxForBatch] = probs_shape[kProbDimsIdxForBatch];
    probs_dims_[kProbDimsIdxForNumClass] = probs_shape[kProbDimsIdxForNumClass];

    if (labels_dims.size() != 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of labels_values must be 1, but got "
                        << labels_dims.size();
    }
    if (indice_dims.size() != kIndicesDimSize) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of labels_indices must be 2, but got "
                        << indice_dims.size();
    }
    label_size_ = sizeof(int);
    label_size_ *= SizeOf(labels_dims);
    label_indice_size_ = sizeof(int64_t);
    label_indice_size_ *= SizeOf(indice_dims);

    sequence_lengths_size_ = LongToSizeClipNeg(sequence_length_dims[0]) * sizeof(int);
    InitSizeLists();
    return KRET_OK;
  }

 protected:
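  // LaunchInit resolves the device addresses of all inputs, outputs, and workspace buffers
  // (indexed by the kInputIdxFor* / kWsIdxFor* constants above) and resets the per-launch state.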
  void LaunchInit(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
                  const std::vector<KernelTensor *> &outputs) {
    probs = GetDeviceAddress<T>(inputs, kInputIdxForProbs);
    label_indices = GetDeviceAddress<int64_t>(inputs, kInputIdxForLabelIndices);
    label_values = GetDeviceAddress<int>(inputs, kInputIdxForLabelValues);
    sequence_length = GetDeviceAddress<int>(inputs, kInputIdxForSeqLen);
    costs = GetDeviceAddress<T>(outputs, 0);
    grads = GetDeviceAddress<T>(outputs, 1);
    softmax_probs = GetDeviceAddress<T>(workspace, kWsIdxForSoftmaxProbs);
    cum_labels_length = GetDeviceAddress<int>(workspace, kWsIdxForCumLabelLen);
    label_squence_length = GetDeviceAddress<int>(workspace, kWsIdxForLabelSquenceLen);
    label_value_sp = GetDeviceAddress<int>(workspace, kWsIdxForLabelValueSp);
    label_value_pcr = GetDeviceAddress<int>(workspace, kWsIdxForLabelValuePcr);
    prob_num = GetDeviceAddress<T>(workspace, kWsIdxForProbNum);
    precum_labels_length = GetDeviceAddress<int>(workspace, kWsIdxForPrecumLabelLen);
    max_labels_length = GetDeviceAddress<int>(workspace, kWsIdxForMaxLabelLen);
    numclass = SizeToInt(probs_dims_[kProbDimsIdxForNumClass]);
    batch = SizeToInt(probs_dims_[kProbDimsIdxForBatch]);
    max_time = SizeToInt(probs_dims_[kProbDimsIdxForMaxTime]);
    max_sequence = 0;
    max_labels_length_host = 0;
    batch_label = 0;
    label_value_with_blank = nullptr;
    log_alpha_b = nullptr;
    log_beta_b = nullptr;
  }

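  // LaunchFirstHalf validates the maximum sequence length against the time dimension of x,
  // runs the softmax kernel (InnerSoftMax) on the input probabilities, and expands the sparse
  // labels (labels_indices/labels_values) into per-batch label sequences, optionally collapsing
  // repeated labels when preprocess_collapse_repeated is set.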
  void LaunchFirstHalf(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
                       const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    cudaError_t status = cudaErrorNotReady;
    status = CalculateMaxSequence(sequence_length, max_labels_length, batch, stream);
    CHECK_CUDA_STATUS(status, "CalculateMaxSequence called by " + kernel_name_);
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(&max_sequence, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    if (max_time < max_sequence) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the x[0] must be equal to or greater than max_sequence, "
                        << "but got x[0]: " << max_time << ", max_sequence: " << max_sequence;
    }
    status = InnerSoftMax(probs, softmax_probs, sequence_length, max_time, batch, numclass, stream);
    CHECK_CUDA_STATUS(status, "InnerSoftMax called by " + kernel_name_);
    MemsetForWS(label_value_pcr, cum_labels_length, label_squence_length, costs, grads, stream);
    status = CalculatePreLength(label_squence_length, precum_labels_length, cum_labels_length, max_labels_length,
                                label_indices, batch, label_size_ / sizeof(int), stream);
    CHECK_CUDA_STATUS(status, "CalculatePreLength called by " + kernel_name_);
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(&batch_label, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    if (batch != batch_label + 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the batch size of input must be equal to "
                        << (batch_label + 1) << ", but got " << batch;
    }
    status = GenLabelValue(label_value_sp, label_indices, label_values, label_squence_length, cum_labels_length,
                           max_labels_length, label_size_ / sizeof(int), numclass - 1, batch, stream);
    CHECK_CUDA_STATUS(status, "GenLabelValue called by " + kernel_name_);
    if (preprocess_collapse_repeated_) {
      status = GenLabelValuePCR(label_value_sp, label_value_pcr, label_squence_length, cum_labels_length,
                                max_labels_length, batch, stream);
      CHECK_CUDA_STATUS(status, "GenLabelValuePCR called by " + kernel_name_);
    }
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemcpyAsync(&max_labels_length_host, max_labels_length, sizeof(int), cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
  }

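  // LaunchSecondHalf runs the CTC forward-backward pass: the labels are augmented with blanks
  // (SOffSet = 2 * max label length + 1 states per batch item), CalculateFwdVar/CalculateBwdVar
  // fill the log_alpha_b/log_beta_b tables, and CTCLoss combines them into the per-batch costs
  // and the gradient with respect to x.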
  void LaunchSecondHalf(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
                        const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
    const int SOffSet = 2 * max_labels_length_host + 1;
    int log_prob_size = batch * SOffSet * max_time;
    cudaError_t status = cudaErrorNotReady;

    if (!ignore_longer_outputs_than_inputs_ && max_labels_length_host > max_time) {
      MS_LOG(EXCEPTION) << "output size is greater than input size.";
    }
    MemManageForCus(&log_alpha_b, &log_beta_b, &label_value_with_blank, cum_labels_length, log_prob_size, batch,
                    stream);

    if (preprocess_collapse_repeated_) {
      status = GenLabelWithBlank(label_value_pcr, label_value_with_blank, label_squence_length, precum_labels_length,
                                 cum_labels_length, batch, numclass - 1, stream);
    } else {
      status = GenLabelWithBlank(label_value_sp, label_value_with_blank, label_squence_length, precum_labels_length,
                                 cum_labels_length, batch, numclass - 1, stream);
    }
    CHECK_CUDA_STATUS(status, "GenLabelWithBlank called by " + kernel_name_);

    status = CalculateFwdVar(log_alpha_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated_,
                             batch, SOffSet, max_time, numclass - 1, label_squence_length, cum_labels_length,
                             ignore_longer_outputs_than_inputs_, stream);
    CHECK_CUDA_STATUS(status, "CalculateFwdVar called by " + kernel_name_);
    status = CalculateBwdVar(log_beta_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated_,
                             batch, SOffSet, max_time, numclass - 1, label_squence_length, cum_labels_length,
                             ignore_longer_outputs_than_inputs_, stream);
    CHECK_CUDA_STATUS(status, "CalculateBwdVar called by " + kernel_name_);
    status = CTCLoss(log_alpha_b, log_beta_b, softmax_probs, label_value_with_blank, batch, SOffSet, max_time, numclass,
                     sequence_length, label_squence_length, cum_labels_length, costs, grads, prob_num,
                     ignore_longer_outputs_than_inputs_, stream);
    CHECK_CUDA_STATUS(status, kernel_name_);
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    FreeMem(label_value_with_blank, log_alpha_b, log_beta_b);
  }

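  // InitSizeLists registers the eight workspace buffer sizes in the same order as the
  // kWsIdxFor* constants, followed by the two output sizes: the per-batch loss and the
  // gradient with the shape of x.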
  void InitSizeLists() {
    workspace_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.push_back(probs_dims_[kProbDimsIdxForMaxTime] * probs_dims_[kProbDimsIdxForBatch] *
                                   probs_dims_[kProbDimsIdxForNumClass] * sizeof(T));
    workspace_size_list_.push_back(sequence_lengths_size_);
    workspace_size_list_.push_back(sequence_lengths_size_);
    workspace_size_list_.push_back(label_size_);
    workspace_size_list_.push_back(label_size_);
    workspace_size_list_.push_back(probs_dims_[kProbDimsIdxForMaxTime] * probs_dims_[kProbDimsIdxForBatch] *
                                   probs_dims_[kProbDimsIdxForNumClass] * sizeof(T));
    workspace_size_list_.push_back(sequence_lengths_size_);
    workspace_size_list_.push_back(sizeof(int));
    output_size_list_.push_back(probs_dims_[kProbDimsIdxForBatch] * sizeof(T));
    output_size_list_.push_back(probs_dims_[kProbDimsIdxForMaxTime] * probs_dims_[kProbDimsIdxForBatch] *
                                probs_dims_[kProbDimsIdxForNumClass] * sizeof(T));
  }
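  // MemsetForWS zero-fills the workspace and output buffers before the compute kernels write
  // to them (collapsed label values, label length counters, costs, and grads).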
  void MemsetForWS(int *label_value_pcr, int *cum_labels_length, int *label_squence_length, T *costs, T *grads,
                   cudaStream_t stream) {
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemsetAsync(label_value_pcr, static_cast<int>(0), label_size_, stream),
                                       "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemsetAsync(cum_labels_length, static_cast<int>(0), sequence_lengths_size_, stream),
      "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemsetAsync(label_squence_length, static_cast<int>(0), sequence_lengths_size_, stream),
      "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemsetAsync(costs, static_cast<T>(0), probs_dims_[kProbDimsIdxForBatch] * sizeof(T), stream),
      "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMemsetAsync(grads, static_cast<T>(0),
                      probs_dims_[kProbDimsIdxForMaxTime] * probs_dims_[kProbDimsIdxForBatch] *
                        probs_dims_[kProbDimsIdxForNumClass] * sizeof(T),
                      stream),
      "cudaMemSet failed in CtcLossGpuKernelMod::Launch.");
  }
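  // MemManageForCus allocates the per-launch device buffers whose sizes are only known after
  // LaunchFirstHalf: the forward/backward log-probability tables and the blank-augmented label
  // values (2 * total label length + batch entries).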
  void MemManageForCus(T **log_alpha_b, T **log_beta_b, int **label_value_with_blank, int *cum_labels_length,
                       int log_prob_size, int batch, cudaStream_t stream) {
    int total_labels_size_host = 0;
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMalloc(reinterpret_cast<void **>(log_alpha_b), sizeof(T) * log_prob_size),
                                       "cudaMalloc failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMalloc(reinterpret_cast<void **>(log_beta_b), sizeof(T) * log_prob_size),
                                       "cudaMalloc failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(&total_labels_size_host, cum_labels_length + batch - 1,
                                                       sizeof(int), cudaMemcpyDeviceToHost, stream),
                                       "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
      cudaMalloc(reinterpret_cast<void **>(label_value_with_blank), sizeof(int) * (2 * total_labels_size_host + batch)),
      "cudaMalloc failed.");
  }

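  // FreeMem releases the buffers allocated in MemManageForCus; the stream is synchronized in
  // LaunchSecondHalf before this is called.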
  void FreeMem(int *label_value_with_blank, T *log_alpha_b, T *log_beta_b) {
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(label_value_with_blank), "cudaFree failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(log_alpha_b), "cudaFree failed.");
    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(log_beta_b), "cudaFree failed.");
  }

  size_t probs_dims_[3] = {0};
  int label_indice_size_;
  int label_size_;
  int sequence_lengths_size_;
  bool preprocess_collapse_repeated_;
  bool ctc_merge_repeated_;
  bool ignore_longer_outputs_than_inputs_;
  bool is_null_input_;
  std::string kernel_name_;
  T kLogZero_ = -std::numeric_limits<T>::infinity();

  // Per-launch device pointers and host-side scalars (initialized in LaunchInit).
  T *probs;
  int64_t *label_indices;
  int *label_values;
  int *sequence_length;
  T *costs;
  T *grads;
  T *softmax_probs;
  int *cum_labels_length;
  int *label_squence_length;
  int *label_value_sp;
  int *label_value_pcr;
  T *prob_num;
  int *precum_labels_length;
  int *max_labels_length;
  int numclass;
  int batch;
  int max_time;
  int max_sequence;
  int max_labels_length_host;
  int batch_label;
  int *label_value_with_blank;
  T *log_alpha_b;
  T *log_beta_b;
};  // class CtcLossGpuKernelMod
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_