/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <limits>
#include "ctcloss_impl.cuh"
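
// Numerically stable log(exp(logprob1) + exp(logprob2)); returns -inf when
// both inputs are -inf, so -inf can serve as "log zero" in the recursions
// below. Note the float intrinsics log1pf/expf: the template is only
// instantiated for T = float (see the explicit instantiations at the end
// of this file).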
template <typename T>
__device__ T LogSumExp(const T logprob1, const T logprob2) {
  if (logprob1 == logprob2 && logprob1 == -std::numeric_limits<T>::infinity()) {
    return logprob1;
  } else {
    return (logprob1 > logprob2) ? logprob1 + log1pf(expf(logprob2 - logprob1))
                                 : logprob2 + log1pf(expf(logprob1 - logprob2));
  }
}

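// Forward (alpha) pass of the CTC dynamic program, one thread per batch
// element (grid-stride loop). log_alpha_b stores log alpha(u, t) for the
// blank-augmented label sequence of length U = 2 * label_length + 1, laid
// out as [batch, SOffSet, maxtime]; states outside the reachable band
// [low, high) keep their -inf initialization.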
template <typename T>
__global__ void CalculateFwdVarKernel(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs,
                                      const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet,
                                      int maxtime, int blank, int *label_squence_length, int *cum_labels_length,
                                      bool ignore_longer_outputs_than_inputs) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    // Skip empty sequences, and labels longer than the input when those are ignored.
    if (sequence_length[i] == 0 ||
        (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
      continue;
    }
    T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime];
    int *label_value_with_blank_cur = &label_value_with_blank[0];
    if (i > 0) {
      label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
    }
    int numclass = blank + 1;
    int U = 2 * label_squence_length[i] + 1;
    int Ti = sequence_length[i];
    int low = 0;
    int high = 0;
    // t = 0: alpha(0, 0) = log p(blank), alpha(1, 0) = log p(first label).
    log_alpha_b_cur[0] = log(softmax_probs[i * numclass + blank]);
    int label0 = blank;
    if (U > 1) {
      label0 = label_value_with_blank_cur[1];
      log_alpha_b_cur[maxtime] = log(softmax_probs[i * numclass + label0]);
    }
    for (int t = 1; t < Ti; ++t) {
      // Restrict u to states that can still reach the end (low) and that are
      // reachable from the start (high).
      low = 0;
      high = U;
      int low_limit = U - (2 * (Ti - t));
      int high_limit = 2 * (t + 1);
      if (low_limit > low) {
        low = low_limit;
      }
      if (high_limit < U) {
        high = high_limit;
      }
      for (int u = low; u < high; ++u) {
        T sum_log_alpha = -std::numeric_limits<T>::infinity();
        // Stay on the same symbol (always allowed for blanks; for labels only
        // when repeats are merged).
        if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) {
          sum_log_alpha = log_alpha_b_cur[u * maxtime + t - 1];
        }
        // Advance from the previous symbol.
        if (u > 0) {
          sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 1) * maxtime + t - 1]);
        }
        // Skip the blank between two distinct labels.
        if (u > 1) {
          const bool matching_labels_merge =
            ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u - 2]);
          if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) {
            sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 2) * maxtime + t - 1]);
          }
        }
        log_alpha_b_cur[u * maxtime + t] =
          log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + t * numclass * batch]) + sum_log_alpha;
      }
    }
  }
}

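// Backward (beta) pass, the mirror image of the forward recursion: runs from
// t = Ti - 1 down to 0 and accumulates the same three transitions (stay,
// advance by one, skip a blank between two distinct labels).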
template <typename T>
__global__ void CalculateBwdVarKernel(T *log_beta_b, int *label_value_with_blank, T *softmax_probs,
                                      const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet,
                                      int maxtime, int blank, int *label_squence_length, int *cum_labels_length,
                                      bool ignore_longer_outputs_than_inputs) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    // Skip empty sequences, and labels longer than the input when those are ignored.
    if (sequence_length[i] == 0 ||
        (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
      continue;
    }
    T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime];
    int *label_value_with_blank_cur = &label_value_with_blank[0];
    if (i > 0) {
      label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
    }
    int numclass = blank + 1;
    int U = 2 * label_squence_length[i] + 1;
    int Ti = sequence_length[i];
    int low = 0;
    int high = 0;
    // Seed t = Ti - 1: a valid path may end on the last label or the final blank.
    if (U > 1) {
      for (int u = U - 2; u < U; ++u) {
        log_beta_b_cur[u * maxtime + Ti - 1] = 0;
      }
    } else {
      // U == 1: only the blank state exists. Note the second write assumes Ti >= 2.
      log_beta_b_cur[Ti - 1] = 0;
      log_beta_b_cur[Ti - 2] = 0;
    }
    for (int t = Ti - 2; t >= 0; --t) {
      low = 0;
      high = U;
      int low_limit = U - (2 * (Ti - t));
      int high_limit = 2 * (t + 1);
      if (low_limit > low) {
        low = low_limit;
      }
      if (high_limit < U) {
        high = high_limit;
      }
      for (int u = low; u < high; ++u) {
        // Stay on the same symbol.
        if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) {
          log_beta_b_cur[u * maxtime + t] = LogSumExp(
            log_beta_b_cur[u * maxtime + t],
            log_beta_b_cur[u * maxtime + t + 1] +
              log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + (t + 1) * numclass * batch]));
        }
        // Advance to the next symbol.
        if (u + 1 < U) {
          log_beta_b_cur[u * maxtime + t] = LogSumExp(
            log_beta_b_cur[u * maxtime + t],
            log_beta_b_cur[(u + 1) * maxtime + t + 1] +
              log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 1] + (t + 1) * numclass * batch]));
        }
        // Skip the blank between two distinct labels.
        if (u + 2 < U) {
          const bool matching_labels_merge =
            ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u + 2]);
          if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) {
            log_beta_b_cur[u * maxtime + t] = LogSumExp(
              log_beta_b_cur[u * maxtime + t],
              log_beta_b_cur[(u + 2) * maxtime + t + 1] +
                log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 2] + (t + 1) * numclass * batch]));
          }
        }
      }
    }
  }
}

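// Initialization kernels: fill a buffer with -inf, the additive identity of
// LogSumExp, before the log-space accumulations.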
template <typename T>
__global__ void ProbInitKernel(T *prob_num, int size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    prob_num[i] = -std::numeric_limits<T>::infinity();
  }
}

template <typename T>
__global__ void LogBInitKernel(T *log_b, int log_prob_size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < log_prob_size; i += blockDim.x * gridDim.x) {
    log_b[i] = -std::numeric_limits<T>::infinity();
  }
}

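// Combines the two DP tables into the loss and the gradient w.r.t. the
// softmax probabilities:
//   log p(z|x) = logsumexp_u(alpha(u, 0) + beta(u, 0)),
//   cost       = -log p(z|x),
//   grad(t, l) = softmax(t, l)
//                - exp(logsumexp_{u: l'_u = l}(alpha(u, t) + beta(u, t)) - log p(z|x)).
// When log p(z|x) is -inf no valid alignment exists and the gradient falls
// back to the raw softmax.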
template <typename T>
__global__ void CTCLossKernel(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch,
                              int SOffSet, int maxtime, int numclass, const int *sequence_length,
                              int *label_squence_length, int *cum_labels_length, T *cost, T *grads, T *prob_num,
                              bool ignore_longer_outputs_than_inputs) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    // Skip empty sequences, and labels longer than the input when those are ignored.
    if (sequence_length[i] == 0 ||
        (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
      continue;
    }
    T *grad_cur = &grads[i * numclass];
    const T *softmax_probs_cur = &softmax_probs[i * numclass];
    T *prob_num_cur = &prob_num[i * numclass];
    int U = 2 * label_squence_length[i] + 1;
    T log_pzx = -std::numeric_limits<T>::infinity();
    const T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime];
    const T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime];
    int *label_value_with_blank_cur = &label_value_with_blank[0];
    if (i > 0) {
      label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
    }
    // log p(z|x), evaluated at t = 0.
    for (int u = 0; u < U; ++u) {
      log_pzx = LogSumExp(log_pzx, log_alpha_b_cur[u * maxtime] + log_beta_b_cur[u * maxtime]);
    }
    cost[i] = -log_pzx;
    // Gradient with respect to the softmax probabilities.
    int L = numclass;
    int Ti = sequence_length[i];
    if (log_pzx == -std::numeric_limits<T>::infinity()) {
      // No valid alignment: fall back to the raw softmax.
      for (int t = 0; t < Ti; ++t) {
        for (int l = 0; l < L; ++l) {
          grad_cur[t * numclass * batch + l] = softmax_probs_cur[t * numclass * batch + l];
        }
      }
    } else {
      for (int t = 0; t < Ti; ++t) {
        // Accumulate log p(label l at time t) over all states u emitting l.
        for (int u = 0; u < U; ++u) {
          int l = label_value_with_blank_cur[u];
          prob_num_cur[t * batch * numclass + l] =
            LogSumExp(prob_num_cur[t * batch * numclass + l],
                      log_alpha_b_cur[u * maxtime + t] + log_beta_b_cur[u * maxtime + t]);
        }
        for (int l = 0; l < L; ++l) {
          grad_cur[t * numclass * batch + l] =
            softmax_probs_cur[t * numclass * batch + l] - expf(prob_num_cur[t * batch * numclass + l] - log_pzx);
        }
      }
    }
  }
}

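// Per-position softmax over the class dimension, computed only for time
// steps inside each sequence's valid length. The layout is time-major:
// index i enumerates (t, b) pairs as i = t * batch + b. maxCoeff starts at
// 0 rather than -inf, so the shift applied before exp() is max(0, max
// logit); that is still enough to keep the exponentials from overflowing.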
template <typename T>
__global__ void InnerSoftMaxKernel(const T *probs, T *softmax_probs, const int *sequence_length, int max_time,
                                   int batch, int numclass) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch * max_time; i += blockDim.x * gridDim.x) {
    int k = i / batch;  // time step
    int m = i % batch;  // batch element
    if (k < sequence_length[m]) {
      T maxCoeff = 0.;
      T sumCoeff = 0.;
      for (int j = i * numclass; j < (i + 1) * numclass; ++j) {
        if (probs[j] > maxCoeff) {
          maxCoeff = probs[j];
        }
      }
      for (int j = i * numclass; j < (i + 1) * numclass; ++j) {
        sumCoeff += exp(probs[j] - maxCoeff);
        softmax_probs[j] = exp(probs[j] - maxCoeff);
      }
      for (int j = i * numclass; j < (i + 1) * numclass; ++j) {
        softmax_probs[j] /= sumCoeff;
      }
    }
  }
}

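// Collapses immediately repeated labels into one (the
// preprocess_collapse_repeated option), rewriting label_squence_length[i]
// to the collapsed length as it goes.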
__global__ void GenLabelValuePCRKernel(int *label_value_sp, int *label_value_pcr, int *label_squence_length,
                                       int *cum_labels_length, int batch) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    int L = label_squence_length[i];
    label_squence_length[i] = 0;
    int offset = 0;
    if (i > 0) {
      offset = cum_labels_length[i - 1];
    }
    // Copy each run of identical labels once.
    for (int l = offset; l < L; ++l) {
      if (l == offset || label_value_sp[l] != label_value_sp[l - 1]) {
        label_value_pcr[offset + label_squence_length[i]++] = label_value_sp[l];
      }
    }
  }
}

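// Serial scan over the batch: recomputes the maximum label length and the
// cumulative (prefix-sum) label lengths. Launched <<<1, 1>>> by the host
// wrappers below.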
__global__ void UpdateLengthKernel(int *label_squence_length, int *cum_labels_length, int *max_labels_length,
                                   int batch) {
  max_labels_length[0] = 0;
  for (int i = 0; i < batch; ++i) {
    if (label_squence_length[i] > max_labels_length[0]) {
      max_labels_length[0] = label_squence_length[i];
    }
    if (i == 0) {
      cum_labels_length[i] = label_squence_length[i];
    } else {
      cum_labels_length[i] = label_squence_length[i] + cum_labels_length[i - 1];
    }
  }
}

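// Host-side launchers. The two DP wrappers first reset their log buffer to
// -inf (log zero) and then run the forward/backward kernel on the caller's
// stream.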
template <typename T>
cudaError_t CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
                            bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank,
                            int *label_squence_length, int *cum_labels_length, bool ignore_longer_outputs_than_inputs,
                            cudaStream_t stream) {
  int log_prob_size = SOffSet * batch * maxtime;
  LogBInitKernel<<<GET_BLOCKS(log_prob_size), GET_THREADS, 0, stream>>>(log_beta_b, log_prob_size);
  CalculateBwdVarKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
    log_beta_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime,
    blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
                            bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank,
                            int *label_squence_length, int *cum_labels_length, bool ignore_longer_outputs_than_inputs,
                            cudaStream_t stream) {
  int log_prob_size = SOffSet * batch * maxtime;
  LogBInitKernel<<<GET_BLOCKS(log_prob_size), GET_THREADS, 0, stream>>>(log_alpha_b, log_prob_size);
  CalculateFwdVarKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
    log_alpha_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime,
    blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs);
  return GetCudaStatus();
}

template <typename T>
cudaError_t InnerSoftMax(const T *probs, T *softmax_probs, const int *sequence_length, int max_time, int batch,
                         int numclass, cudaStream_t stream) {
  InnerSoftMaxKernel<<<GET_BLOCKS(batch * max_time), GET_THREADS, 0, stream>>>(probs, softmax_probs, sequence_length,
                                                                               max_time, batch, numclass);
  return GetCudaStatus();
}

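// Interleaves blanks with the labels of every batch element, producing
// l' = (blank, l_1, blank, l_2, ..., l_n, blank) of length 2n + 1; batch i
// starts at offset 2 * cum_labels_length[i - 1] + i in label_value_with_blank.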
__global__ void GenLabelWithBlankKernel(int *label_value, int *label_value_with_blank, int *label_squence_length,
                                        int *precum_labels_length, int *cum_labels_length, int batch, int blank) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    int offset = 0;   // start of batch i in label_value_with_blank
    int offset1 = 0;  // start of batch i in label_value
    if (i > 0) {
      offset = 2 * cum_labels_length[i - 1] + i;
      offset1 = precum_labels_length[i - 1];
    }
    for (int j = 0; j < label_squence_length[i]; ++j) {
      label_value_with_blank[offset + 2 * j] = blank;
      label_value_with_blank[offset + 2 * j + 1] = label_value[offset1 + j];
    }
    label_value_with_blank[offset + 2 * label_squence_length[i]] = blank;
  }
}

cudaError_t GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length,
                              int *precum_labels_length, int *cum_labels_length, int batch, int blank,
                              cudaStream_t stream) {
  GenLabelWithBlankKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
    label_value, label_value_with_blank, label_squence_length, precum_labels_length, cum_labels_length, batch, blank);
  return GetCudaStatus();
}

cudaError_t GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length,
                             int *cum_labels_length, int *max_labels_length, int batch, cudaStream_t stream) {
  GenLabelValuePCRKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(label_value_sp, label_value_pcr,
                                                                        label_squence_length, cum_labels_length, batch);
  UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch);
  return GetCudaStatus();
}

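// Sparse-label preprocessing: LabelValueInitKernel fills the dense buffer
// with the blank value, GenLabelValueKernel scatters the sparse (indices,
// values) pairs into it, and RecalculateLengthKernel trims each batch's
// length at the first slot still holding the blank fill.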
__global__ void GenLabelValueKernel(int *label_value_sp, const int64_t *label_indices, const int *label_values,
                                    int *label_squence_length, int *cum_labels_length, int size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    int64_t b = label_indices[i * 2];  // batch index of sparse element i
    int offset = 0;
    if (b > 0) {
      offset = cum_labels_length[b - 1];
    }
    int64_t index = offset + label_indices[i * 2 + 1];
    label_value_sp[index] = label_values[i];
  }
}

__global__ void LabelValueInitKernel(int *label_value_sp, int size, int blank) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    label_value_sp[i] = blank;
  }
}

__global__ void RecalculateLengthKernel(int *label_value_sp, int *label_squence_length, int *cum_labels_length,
                                        int batch, int blank) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    int offset = 0;
    if (i > 0) {
      offset = cum_labels_length[i - 1];
    }
    int L = label_squence_length[i];
    label_squence_length[i] = 0;
    // Count labels up to the first slot still holding the blank fill value.
    for (int j = offset; j < offset + L; ++j) {
      if (label_value_sp[j] >= blank) {
        break;
      } else {
        label_squence_length[i]++;
      }
    }
  }
}

cudaError_t GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values,
                          int *label_squence_length, int *cum_labels_length, int *max_labels_length, int size,
                          int blank, int batch, cudaStream_t stream) {
  LabelValueInitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(label_value_sp, size, blank);
  GenLabelValueKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(label_value_sp, label_indices, label_values,
                                                                    label_squence_length, cum_labels_length, size);
  RecalculateLengthKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(label_value_sp, label_squence_length,
                                                                         cum_labels_length, batch, blank);
  UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch);
  return GetCudaStatus();
}

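// Serial pass over the sparse indices: counts the labels of each batch
// element and builds the cumulative-length tables consumed by the kernels
// above; max_labels_length[0] here receives the largest batch index seen
// in label_indices.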
__global__ void CalculatePreLengthKernel(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
                                         int *max_labels_length, const int64_t *label_indices, int batch, int size) {
  max_labels_length[0] = 0;
  for (int i = 0; i < size; ++i) {
    label_squence_length[label_indices[i * 2]]++;
    if (max_labels_length[0] < label_indices[i * 2]) {
      max_labels_length[0] = label_indices[i * 2];
    }
  }
  precum_labels_length[0] = label_squence_length[0];
  cum_labels_length[0] = label_squence_length[0];
  for (int i = 1; i < batch; ++i) {
    cum_labels_length[i] = cum_labels_length[i - 1] + label_squence_length[i];
    precum_labels_length[i] = precum_labels_length[i - 1] + label_squence_length[i];
  }
}

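// Serial max over sequence_length, written to max_labels_length[0].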
__global__ void CalculateMaxSequenceKernel(const int *sequence_length, int *max_labels_length, int batch) {
  max_labels_length[0] = 0;
  for (int i = 0; i < batch; ++i) {
    if (sequence_length[i] > max_labels_length[0]) {
      max_labels_length[0] = sequence_length[i];
    }
  }
}

cudaError_t CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, cudaStream_t stream) {
  CalculateMaxSequenceKernel<<<1, 1, 0, stream>>>(sequence_length, max_labels_length, batch);
  return GetCudaStatus();
}

cudaError_t CalculatePreLength(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
                               int *max_labels_length, const int64_t *label_indices, int batch, int size,
                               cudaStream_t stream) {
  CalculatePreLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, precum_labels_length, cum_labels_length,
                                                max_labels_length, label_indices, batch, size);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch,
                    int SOffSet, int maxtime, int numclass, const int *sequence_length, int *label_squence_length,
                    int *cum_labels_length, T *cost, T *grads, T *prob_num, bool ignore_longer_outputs_than_inputs,
                    cudaStream_t stream) {
  ProbInitKernel<<<GET_BLOCKS(maxtime * batch * numclass), GET_THREADS, 0, stream>>>(prob_num,
                                                                                     maxtime * batch * numclass);
  CTCLossKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
    log_alpha_b, log_beta_b, softmax_probs, label_value_with_blank, batch, SOffSet, maxtime, numclass, sequence_length,
    label_squence_length, cum_labels_length, cost, grads, prob_num, ignore_longer_outputs_than_inputs);
  return GetCudaStatus();
}

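// Explicit instantiations: the op is compiled for float only, matching the
// float intrinsics (expf/log1pf) used above. A plausible host-side call
// order, judging from the data dependencies in this file (the actual order
// is fixed by the GPU kernel that includes ctcloss_impl.cuh):
//   CalculateMaxSequence -> CalculatePreLength -> GenLabelValue
//   [-> GenLabelValuePCR] -> GenLabelWithBlank -> InnerSoftMax
//   -> CalculateFwdVar -> CalculateBwdVar -> CTCLoss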
template CUDA_LIB_EXPORT cudaError_t CalculateFwdVar<float>(
  float *log_alpha_b, int *label_value_with_blank, float *softmax_probs, const int *sequence_length,
  bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
  int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);

template CUDA_LIB_EXPORT cudaError_t CalculateBwdVar<float>(
  float *log_beta_b, int *label_value_with_blank, float *softmax_probs, const int *sequence_length,
  bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
  int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);

template CUDA_LIB_EXPORT cudaError_t InnerSoftMax<float>(const float *probs, float *softmax_probs,
                                                         const int *sequence_length, int max_time, int batch,
                                                         int numclass, cudaStream_t stream);

template CUDA_LIB_EXPORT cudaError_t CTCLoss<float>(float *log_alpha_b, float *log_beta_b, float *softmax_probs,
                                                    int *label_value_with_blank, int batch, int SOffSet, int maxtime,
                                                    int numclass, const int *sequence_length, int *label_squence_length,
                                                    int *cum_labels_length, float *cost, float *grads, float *prob_num,
                                                    bool ignore_longer_outputs_than_inputs, cudaStream_t stream);