/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

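// CTC loss GPU implementation. The call order is owned by the GPU kernel that
// includes ctcloss_impl.cuh, but the launcher signatures suggest the expected
// flow: build dense labels (GenLabelValue / GenLabelValuePCR /
// CalculatePreLength), interleave blanks (GenLabelWithBlank), softmax the
// logits (InnerSoftMax), run CalculateFwdVar and CalculateBwdVar, and finally
// CTCLoss to produce cost and gradients.
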
#include <limits>
#include "ctcloss_impl.cuh"
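
// Numerically stable log(exp(logprob1) + exp(logprob2)). The equality check
// short-circuits the case where both inputs are log(0) = -inf, which would
// otherwise evaluate exp(-inf - -inf) = exp(NaN). The single-precision
// log1pf/expf intrinsics are fine here because only the float instantiations
// at the bottom of this file are exported.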
template <typename T>
__device__ T LogSumExp(const T logprob1, const T logprob2) {
  if (logprob1 == logprob2 && logprob1 == -std::numeric_limits<T>::infinity()) {
    return logprob1;
  } else {
    return (logprob1 > logprob2) ? logprob1 + log1pf(expf(logprob2 - logprob1))
                                 : logprob2 + log1pf(expf(logprob1 - logprob2));
  }
}

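// Forward pass of the CTC dynamic program (Graves et al., 2006). For each
// batch element i this fills log_alpha_b with the log forward variables over
// the blank-augmented label sequence of length U = 2 * label_length + 1:
//   alpha(u, t) = y_t(label_u) * (alpha(u, t-1) + alpha(u-1, t-1)
//                                 [+ alpha(u-2, t-1) when the skip is legal]),
// where y_t is the softmax output at time t, evaluated in log space with
// LogSumExp. The [low, high) window skips states that can no longer reach the
// final state in the remaining time steps.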
template <typename T>
__global__ void CalculateFwdVarKernel(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs,
                                      const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet,
                                      int maxtime, int blank, int *label_squence_length, int *cum_labels_length,
                                      bool ignore_longer_outputs_than_inputs) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    if (sequence_length[i] == 0 ||
        (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
      continue;
    }
    T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime];
    int *label_value_with_blank_cur = &label_value_with_blank[0];
    if (i > 0) {
      label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
    }
    int numclass = blank + 1;
    int U = 2 * label_squence_length[i] + 1;
    int Ti = sequence_length[i];
    int low = 0;
    int high = 0;
    log_alpha_b_cur[0] = log(softmax_probs[i * numclass + blank]);
    int label0 = blank;
    if (U > 1) {
      label0 = label_value_with_blank_cur[1];
      log_alpha_b_cur[maxtime] = log(softmax_probs[i * numclass + label0]);
    }
    for (int t = 1; t < Ti; ++t) {
      low = 0;
      high = U;
      int low_limit = U - (2 * (Ti - t));
      int high_limit = 2 * (t + 1);
      if (low_limit > low) {
        low = low_limit;
      }
      if (high_limit < U) {
        high = high_limit;
      }
      for (int u = low; u < high; ++u) {
        T sum_log_alpha = -std::numeric_limits<T>::infinity();
        if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) {
          sum_log_alpha = log_alpha_b_cur[u * maxtime + t - 1];
        }
        if (u > 0) {
          sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 1) * maxtime + t - 1]);
        }
        if (u > 1) {
          const bool matching_labels_merge =
            ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u - 2]);
          if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) {
            sum_log_alpha = LogSumExp(sum_log_alpha, log_alpha_b_cur[(u - 2) * maxtime + t - 1]);
          }
        }
        log_alpha_b_cur[u * maxtime + t] =
          log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + t * numclass * batch]) + sum_log_alpha;
      }
    }
  }
}

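// Backward pass of the CTC dynamic program: fills log_beta_b with the log
// backward variables
//   beta(u, t) = sum over v in {u, u+1, u+2} of y_{t+1}(label_v) * beta(v, t+1)
// (v = u + 2 only when the skip is legal), initialised in log space with
// beta(U-1, Ti-1) = beta(U-2, Ti-1) = 0.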
template <typename T>
__global__ void CalculateBwdVarKernel(T *log_beta_b, int *label_value_with_blank, T *softmax_probs,
                                      const int *sequence_length, bool ctc_merge_repeated, int batch, int SOffSet,
                                      int maxtime, int blank, int *label_squence_length, int *cum_labels_length,
                                      bool ignore_longer_outputs_than_inputs) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    if (sequence_length[i] == 0 ||
        (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
      continue;
    }
    T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime];
    int *label_value_with_blank_cur = &label_value_with_blank[0];
    if (i > 0) {
      label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
    }
    int numclass = blank + 1;
    int U = 2 * label_squence_length[i] + 1;
    int Ti = sequence_length[i];
    int low = 0;
    int high = 0;
    if (U > 1) {
      for (int u = U - 2; u < U; ++u) {
        log_beta_b_cur[u * maxtime + Ti - 1] = 0;
      }
    } else {
      // Empty label sequence (U == 1): only the single blank state exists, so
      // only beta(0, Ti - 1) is initialised to log(1) = 0. Touching index
      // Ti - 2 as well would seed a spurious path and writes out of bounds
      // when Ti == 1.
      log_beta_b_cur[Ti - 1] = 0;
    }
    for (int t = Ti - 2; t >= 0; --t) {
      low = 0;
      high = U;
      int low_limit = U - (2 * (Ti - t));
      int high_limit = 2 * (t + 1);
      if (low_limit > low) {
        low = low_limit;
      }
      if (high_limit < U) {
        high = high_limit;
      }
      for (int u = low; u < high; ++u) {
        if (ctc_merge_repeated || label_value_with_blank_cur[u] == blank) {
          log_beta_b_cur[u * maxtime + t] = LogSumExp(
            log_beta_b_cur[u * maxtime + t],
            log_beta_b_cur[u * maxtime + t + 1] +
              log(softmax_probs[i * numclass + label_value_with_blank_cur[u] + (t + 1) * numclass * batch]));
        }
        if (u + 1 < U) {
          log_beta_b_cur[u * maxtime + t] = LogSumExp(
            log_beta_b_cur[u * maxtime + t],
            log_beta_b_cur[(u + 1) * maxtime + t + 1] +
              log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 1] + (t + 1) * numclass * batch]));
        }
        if (u + 2 < U) {
          const bool matching_labels_merge =
            ctc_merge_repeated && (label_value_with_blank_cur[u] == label_value_with_blank_cur[u + 2]);
          if (label_value_with_blank_cur[u] != blank && !matching_labels_merge) {
            log_beta_b_cur[u * maxtime + t] = LogSumExp(
              log_beta_b_cur[u * maxtime + t],
              log_beta_b_cur[(u + 2) * maxtime + t + 1] +
                log(softmax_probs[i * numclass + label_value_with_blank_cur[u + 2] + (t + 1) * numclass * batch]));
          }
        }
      }
    }
  }
}

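// Buffer initialisation: the log-space accumulators start at log(0) = -inf
// before anything is summed into them.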
template <typename T>
__global__ void ProbInitKernel(T *prob_num, int size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    prob_num[i] = -std::numeric_limits<T>::infinity();
  }
}

template <typename T>
__global__ void LogBInitKernel(T *log_b, int log_prob_size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < log_prob_size; i += blockDim.x * gridDim.x) {
    log_b[i] = -std::numeric_limits<T>::infinity();
  }
}

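// Combines the forward and backward variables into the loss and its gradient:
//   log p(z|x) = logsumexp_u(alpha(u, 0) + beta(u, 0)),   cost = -log p(z|x)
//   grad(t, l) = softmax(t, l)
//                - sum over {u : label_u == l} of exp(alpha(u, t) + beta(u, t) - log p(z|x)).
// When no valid alignment exists (log p(z|x) = -inf), the gradient falls back
// to the raw softmax probabilities.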
template <typename T>
__global__ void CTCLossKernel(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch,
                              int SOffSet, int maxtime, int numclass, const int *sequence_length,
                              int *label_squence_length, int *cum_labels_length, T *cost, T *grads, T *prob_num,
                              bool ignore_longer_outputs_than_inputs) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    if (sequence_length[i] == 0 ||
        (ignore_longer_outputs_than_inputs && label_squence_length[i] > sequence_length[i])) {
      continue;
    }
    T *grad_cur = &grads[i * numclass];
    const T *softmax_probs_cur = &softmax_probs[i * numclass];
    T *prob_num_cur = &prob_num[i * numclass];
    int U = 2 * label_squence_length[i] + 1;
    T log_pzx = -std::numeric_limits<T>::infinity();
    const T *log_alpha_b_cur = &log_alpha_b[i * SOffSet * maxtime];
    const T *log_beta_b_cur = &log_beta_b[i * SOffSet * maxtime];
    int *label_value_with_blank_cur = &label_value_with_blank[0];
    if (i > 0) {
      label_value_with_blank_cur = &label_value_with_blank[2 * cum_labels_length[i - 1] + i];
    }
    for (int u = 0; u < U; ++u) {
      log_pzx = LogSumExp(log_pzx, log_alpha_b_cur[u * maxtime] + log_beta_b_cur[u * maxtime]);
    }
    cost[i] = -log_pzx;
    // grad
    int L = numclass;
    int Ti = sequence_length[i];
    if (log_pzx == -std::numeric_limits<T>::infinity()) {
      for (int t = 0; t < Ti; ++t) {
        for (int l = 0; l < L; ++l) {
          grad_cur[t * numclass * batch + l] = softmax_probs_cur[t * numclass * batch + l];
        }
      }
    } else {
      for (int t = 0; t < Ti; ++t) {
        for (int u = 0; u < U; ++u) {
          int l = label_value_with_blank_cur[u];
          prob_num_cur[t * batch * numclass + l] =
            LogSumExp(prob_num_cur[t * batch * numclass + l],
                      log_alpha_b_cur[u * maxtime + t] + log_beta_b_cur[u * maxtime + t]);
        }
        for (int l = 0; l < L; ++l) {
          grad_cur[t * numclass * batch + l] =
            softmax_probs_cur[t * numclass * batch + l] - expf(prob_num_cur[t * batch * numclass + l] - log_pzx);
        }
      }
    }
  }
}

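// Softmax over the class dimension for every (time, batch) position inside the
// sequence (k < sequence_length[m]); positions past the end of a sequence are
// left untouched. probs is laid out as [max_time, batch, numclass].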
template <typename T>
__global__ void InnerSoftMaxKernel(const T *probs, T *softmax_probs, const int *sequence_length, int max_time,
                                   int batch, int numclass) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch * max_time; i += blockDim.x * gridDim.x) {
    int k = i / batch;
    int m = i % batch;
    if (k < sequence_length[m]) {
      // Subtract the row maximum before exponentiating for numerical
      // stability; start from the first logit so all-negative rows are
      // handled correctly as well.
      T maxCoeff = probs[i * numclass];
      T sumCoeff = 0.;
      for (int j = i * numclass + 1; j < (i + 1) * numclass; ++j) {
        if (probs[j] > maxCoeff) {
          maxCoeff = probs[j];
        }
      }
      for (int j = i * numclass; j < (i + 1) * numclass; ++j) {
        softmax_probs[j] = exp(probs[j] - maxCoeff);
        sumCoeff += softmax_probs[j];
      }
      for (int j = i * numclass; j < (i + 1) * numclass; ++j) {
        softmax_probs[j] /= sumCoeff;
      }
    }
  }
}

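// Collapses consecutive repeated labels into label_value_pcr (the
// preprocess_collapse_repeated option) and rewrites label_squence_length[i]
// to the collapsed length.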
__global__ void GenLabelValuePCRKernel(int *label_value_sp, int *label_value_pcr, int *label_squence_length,
                                       int *cum_labels_length, int batch) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    int L = label_squence_length[i];
    label_squence_length[i] = 0;
    int offset = 0;
    if (i > 0) {
      offset = cum_labels_length[i - 1];
    }
    // Batch element i owns label_value_sp[offset, offset + L).
    for (int l = offset; l < offset + L; ++l) {
      if (l == offset || label_value_sp[l] != label_value_sp[l - 1]) {
        label_value_pcr[offset + label_squence_length[i]++] = label_value_sp[l];
      }
    }
  }
}

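// Single-thread kernel: after the per-batch label lengths change, recompute
// their cumulative sums and the maximum length.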
__global__ void UpdateLengthKernel(int *label_squence_length, int *cum_labels_length, int *max_labels_length,
                                   int batch) {
  max_labels_length[0] = 0;
  for (int i = 0; i < batch; ++i) {
    if (label_squence_length[i] > max_labels_length[0]) {
      max_labels_length[0] = label_squence_length[i];
    }
    if (i == 0) {
      cum_labels_length[i] = label_squence_length[i];
    } else {
      cum_labels_length[i] = label_squence_length[i] + cum_labels_length[i - 1];
    }
  }
}

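// Host-side launchers: each seeds its log buffer with -inf where needed,
// chains the kernels on the caller's stream, and reports launch errors via
// GetCudaStatus().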
template <typename T>
cudaError_t CalculateBwdVar(T *log_beta_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
                            bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank,
                            int *label_squence_length, int *cum_labels_length, bool ignore_longer_outputs_than_inputs,
                            cudaStream_t stream) {
  int log_prob_size = SOffSet * batch * maxtime;
  LogBInitKernel<<<GET_BLOCKS(log_prob_size), GET_THREADS, 0, stream>>>(log_beta_b, log_prob_size);
  CalculateBwdVarKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
    log_beta_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime,
    blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CalculateFwdVar(T *log_alpha_b, int *label_value_with_blank, T *softmax_probs, const int *sequence_length,
                            bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank,
                            int *label_squence_length, int *cum_labels_length, bool ignore_longer_outputs_than_inputs,
                            cudaStream_t stream) {
  int log_prob_size = SOffSet * batch * maxtime;
  LogBInitKernel<<<GET_BLOCKS(log_prob_size), GET_THREADS, 0, stream>>>(log_alpha_b, log_prob_size);
  CalculateFwdVarKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
    log_alpha_b, label_value_with_blank, softmax_probs, sequence_length, ctc_merge_repeated, batch, SOffSet, maxtime,
    blank, label_squence_length, cum_labels_length, ignore_longer_outputs_than_inputs);
  return GetCudaStatus();
}

template <typename T>
cudaError_t InnerSoftMax(const T *probs, T *softmax_probs, const int *sequence_length, int max_time, int batch,
                         int numclass, cudaStream_t stream) {
  InnerSoftMaxKernel<<<GET_BLOCKS(batch * max_time), GET_THREADS, 0, stream>>>(probs, softmax_probs, sequence_length,
                                                                               max_time, batch, numclass);
  return GetCudaStatus();
}

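// Interleaves blanks with the label values: (l1, l2, ..., lN) becomes
// (blank, l1, blank, l2, ..., lN, blank), the 2N + 1 state sequence walked by
// the forward/backward kernels. Batch element i starts at offset
// 2 * cum_labels_length[i - 1] + i in the output.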
__global__ void GenLabelWithBlankKernel(int *label_value, int *label_value_with_blank, int *label_squence_length,
                                        int *precum_labels_length, int *cum_labels_length, int batch, int blank) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    int offset = 0;
    int offset1 = 0;
    if (i > 0) {
      offset = 2 * cum_labels_length[i - 1] + i;
      offset1 = precum_labels_length[i - 1];
    }
    for (int j = 0; j < label_squence_length[i]; ++j) {
      label_value_with_blank[offset + 2 * j] = blank;
      label_value_with_blank[offset + 2 * j + 1] = label_value[offset1 + j];
    }
    label_value_with_blank[offset + 2 * label_squence_length[i]] = blank;
  }
}

cudaError_t GenLabelWithBlank(int *label_value, int *label_value_with_blank, int *label_squence_length,
                              int *precum_labels_length, int *cum_labels_length, int batch, int blank,
                              cudaStream_t stream) {
  GenLabelWithBlankKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
    label_value, label_value_with_blank, label_squence_length, precum_labels_length, cum_labels_length, batch, blank);
  return GetCudaStatus();
}

cudaError_t GenLabelValuePCR(int *label_value_sp, int *label_value_pcr, int *label_squence_length,
                             int *cum_labels_length, int *max_labels_length, int batch, cudaStream_t stream) {
  GenLabelValuePCRKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(label_value_sp, label_value_pcr,
                                                                        label_squence_length, cum_labels_length, batch);
  UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch);
  return GetCudaStatus();
}

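// Dense label construction from a sparse (indices, values) tensor:
// LabelValueInitKernel fills the dense buffer with `blank`,
// GenLabelValueKernel scatters the sparse values into their per-batch slots,
// and RecalculateLengthKernel trims each length at the first slot that is
// still >= blank. GenLabelValue below chains the three.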
__global__ void GenLabelValueKernel(int *label_value_sp, const int64_t *label_indices, const int *label_values,
                                    int *label_squence_length, int *cum_labels_length, int size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    int64_t b = label_indices[i * 2];
    int offset = 0;
    if (b > 0) {
      offset = cum_labels_length[b - 1];
    }
    int64_t index = offset + label_indices[i * 2 + 1];
    label_value_sp[index] = label_values[i];
  }
}

__global__ void LabelValueInitKernel(int *label_value_sp, int size, int blank) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    label_value_sp[i] = blank;
  }
}

__global__ void RecalculateLengthKernel(int *label_value_sp, int *label_squence_length, int *cum_labels_length,
                                        int batch, int blank) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch; i += blockDim.x * gridDim.x) {
    int offset = 0;
    if (i > 0) {
      offset = cum_labels_length[i - 1];
    }
    int L = label_squence_length[i];
    label_squence_length[i] = 0;
    for (int j = offset; j < offset + L; ++j) {
      if (label_value_sp[j] >= blank) {
        break;
      } else {
        label_squence_length[i]++;
      }
    }
  }
}

cudaError_t GenLabelValue(int *label_value_sp, const int64_t *label_indices, const int *label_values,
                          int *label_squence_length, int *cum_labels_length, int *max_labels_length, int size,
                          int blank, int batch, cudaStream_t stream) {
  LabelValueInitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(label_value_sp, size, blank);
  GenLabelValueKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(label_value_sp, label_indices, label_values,
                                                                    label_squence_length, cum_labels_length, size);
  RecalculateLengthKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(label_value_sp, label_squence_length,
                                                                         cum_labels_length, batch, blank);
  UpdateLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, cum_labels_length, max_labels_length, batch);
  return GetCudaStatus();
}

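// Single-thread kernel: counts the labels per batch element from the sparse
// indices and builds both cumulative-length arrays. Note that
// max_labels_length here records the largest batch index seen in
// label_indices, not a label length.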
__global__ void CalculatePreLengthKernel(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
                                         int *max_labels_length, const int64_t *label_indices, int batch, int size) {
  max_labels_length[0] = 0;
  for (int i = 0; i < size; ++i) {
    label_squence_length[label_indices[i * 2]]++;
    if (max_labels_length[0] < label_indices[i * 2]) {
      max_labels_length[0] = label_indices[i * 2];
    }
  }
  precum_labels_length[0] = label_squence_length[0];
  cum_labels_length[0] = label_squence_length[0];
  for (int i = 1; i < batch; ++i) {
    cum_labels_length[i] = cum_labels_length[i - 1] + label_squence_length[i];
    precum_labels_length[i] = precum_labels_length[i - 1] + label_squence_length[i];
  }
}

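// Single-thread kernel: reduces sequence_length to its maximum.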
__global__ void CalculateMaxSequenceKernel(const int *sequence_length, int *max_labels_length, int batch) {
  max_labels_length[0] = 0;
  for (int i = 0; i < batch; ++i) {
    if (sequence_length[i] > max_labels_length[0]) {
      max_labels_length[0] = sequence_length[i];
    }
  }
}

cudaError_t CalculateMaxSequence(const int *sequence_length, int *max_labels_length, int batch, cudaStream_t stream) {
  CalculateMaxSequenceKernel<<<1, 1, 0, stream>>>(sequence_length, max_labels_length, batch);
  return GetCudaStatus();
}

cudaError_t CalculatePreLength(int *label_squence_length, int *precum_labels_length, int *cum_labels_length,
                               int *max_labels_length, const int64_t *label_indices, int batch, int size,
                               cudaStream_t stream) {
  CalculatePreLengthKernel<<<1, 1, 0, stream>>>(label_squence_length, precum_labels_length, cum_labels_length,
                                                max_labels_length, label_indices, batch, size);
  return GetCudaStatus();
}

template <typename T>
cudaError_t CTCLoss(T *log_alpha_b, T *log_beta_b, T *softmax_probs, int *label_value_with_blank, int batch,
                    int SOffSet, int maxtime, int numclass, const int *sequence_length, int *label_squence_length,
                    int *cum_labels_length, T *cost, T *grads, T *prob_num, bool ignore_longer_outputs_than_inputs,
                    cudaStream_t stream) {
  ProbInitKernel<<<GET_BLOCKS(maxtime * batch * numclass), GET_THREADS, 0, stream>>>(prob_num,
                                                                                     maxtime * batch * numclass);
  CTCLossKernel<<<GET_BLOCKS(batch), GET_THREADS, 0, stream>>>(
    log_alpha_b, log_beta_b, softmax_probs, label_value_with_blank, batch, SOffSet, maxtime, numclass, sequence_length,
    label_squence_length, cum_labels_length, cost, grads, prob_num, ignore_longer_outputs_than_inputs);
  return GetCudaStatus();
}

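// Explicit instantiations: only the float specialisations are exported.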
template CUDA_LIB_EXPORT cudaError_t CalculateFwdVar<float>(
  float *log_alpha_b, int *label_value_with_blank, float *softmax_probs, const int *sequence_length,
  bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
  int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);

template CUDA_LIB_EXPORT cudaError_t CalculateBwdVar<float>(
  float *log_beta_b, int *label_value_with_blank, float *softmax_probs, const int *sequence_length,
  bool ctc_merge_repeated, int batch, int SOffSet, int maxtime, int blank, int *label_squence_length,
  int *cum_labels_length, bool ignore_longer_outputs_than_inputs, cudaStream_t stream);

template CUDA_LIB_EXPORT cudaError_t InnerSoftMax<float>(const float *probs, float *softmax_probs,
                                                         const int *sequence_length, int max_time, int batch,
                                                         int numclass, cudaStream_t stream);

template CUDA_LIB_EXPORT cudaError_t CTCLoss<float>(float *log_alpha_b, float *log_beta_b, float *softmax_probs,
                                                    int *label_value_with_blank, int batch, int SOffSet, int maxtime,
                                                    int numclass, const int *sequence_length, int *label_squence_length,
                                                    int *cum_labels_length, float *cost, float *grads, float *prob_num,
                                                    bool ignore_longer_outputs_than_inputs, cudaStream_t stream);