• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_
18 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_
19 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
20 
// CalculateFwdVar - host entry point for the CTC forward ("alpha") pass over element type T.
// NOTE(review): the roles below are read off parameter names and const-ness only; confirm them
// against the .cu definition before relying on them.
//   log_alpha_b               non-const T*; presumably receives the log-domain alpha values.
//   label_value_with_blank    non-const int*; presumably the output of GenLabelWithBlank.
//   softmax_probs             non-const T*; presumably the output of InnerSoftMax.
//   sequence_length           read-only (const) per-batch input lengths.
//   ctc_merge_repeated        presumably mirrors the identically named TensorFlow ctc_loss option.
//   batch                     batch count.
//   SOffSet                   presumably the per-batch stride of the blank-extended label axis.
//   maxtime                   presumably the maximum number of time steps.
//   blank                     blank class index.
//   label_squence_length      presumably per-batch label lengths.
//   cum_labels_length         presumably running sums of the label lengths.
//   ignore_longer_outputs_than_inputs
//                             presumably mirrors the identically named TensorFlow ctc_loss option.
//   stream                    CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
template <typename T>
CUDA_LIB_EXPORT cudaError_t CalculateFwdVar(T *log_alpha_b,
                                            int *label_value_with_blank,
                                            T *softmax_probs,
                                            const int *sequence_length,
                                            bool ctc_merge_repeated,
                                            int batch,
                                            int SOffSet,
                                            int maxtime,
                                            int blank,
                                            int *label_squence_length,
                                            int *cum_labels_length,
                                            bool ignore_longer_outputs_than_inputs,
                                            cudaStream_t stream);
26 
// CalculateBwdVar - host entry point for the CTC backward ("beta") pass over element type T.
// Its parameter list matches CalculateFwdVar's, except that the first buffer is `log_beta_b`.
// NOTE(review): roles are inferred from names and const-ness only; confirm against the .cu file.
//   log_beta_b                non-const T*; presumably receives the log-domain beta values.
//   label_value_with_blank    non-const int*; presumably the blank-extended labels.
//   softmax_probs             non-const T*; presumably the softmax-normalised inputs.
//   sequence_length           read-only (const) per-batch input lengths.
//   ctc_merge_repeated, ignore_longer_outputs_than_inputs
//                             presumably mirror the identically named TensorFlow ctc_loss options.
//   batch, SOffSet, maxtime, blank
//                             batch count, presumably the extended-label stride, presumably the
//                             maximum time-step count, and the blank class index.
//   label_squence_length, cum_labels_length
//                             presumably per-batch label lengths and their running sums.
//   stream                    CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
template <typename T>
CUDA_LIB_EXPORT cudaError_t CalculateBwdVar(T *log_beta_b,
                                            int *label_value_with_blank,
                                            T *softmax_probs,
                                            const int *sequence_length,
                                            bool ctc_merge_repeated,
                                            int batch,
                                            int SOffSet,
                                            int maxtime,
                                            int blank,
                                            int *label_squence_length,
                                            int *cum_labels_length,
                                            bool ignore_longer_outputs_than_inputs,
                                            cudaStream_t stream);
32 
// InnerSoftMax - host entry point that, going by its name, applies a softmax to `probs`.
// NOTE(review): the memory layout of `probs` is not visible from this header; presumably it is
// indexed by (max_time, batch, numclass) -- TODO confirm against the .cu definition.
//   probs              read-only (const) input values of type T.
//   softmax_cost       non-const T*; presumably the softmax output.
//   sequence_length    read-only (const) per-batch lengths.
//   max_time, batch, numclass
//                      extents of the input.
//   stream             CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
template <typename T>
CUDA_LIB_EXPORT cudaError_t InnerSoftMax(const T *probs,
                                         T *softmax_cost,
                                         const int *sequence_length,
                                         int max_time,
                                         int batch,
                                         int numclass,
                                         cudaStream_t stream);
36 
// GenLabelValuePCR - label-preparation step operating purely on int buffers.
// NOTE(review): "PCR" is not expanded anywhere in this header; presumably it stands for
// "preprocess collapse repeated" and derives `label_value_pcr` from `label_value_sp`.
// Confirm against the .cu definition.
//   label_value_sp, label_value_pcr
//                      non-const int label buffers.
//   label_squence_length, cum_labels_length, max_labels_length
//                      non-const per-batch length bookkeeping buffers.
//   batch              batch count.
//   stream             CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
CUDA_LIB_EXPORT cudaError_t GenLabelValuePCR(int *label_value_sp,
                                             int *label_value_pcr,
                                             int *label_squence_length,
                                             int *cum_labels_length,
                                             int *max_labels_length,
                                             int batch,
                                             cudaStream_t stream);
40 
// GenLabelWithBlank - going by its name, builds `label_value_with_blank` from `label_value`
// by inserting the `blank` class index. Exactly where the blanks are placed is not visible
// here -- TODO confirm against the .cu definition.
//   label_value, label_value_with_blank
//                      non-const int buffers (presumably source and destination).
//   label_squence_length, precum_labels_length, cum_labels_length
//                      non-const per-batch length bookkeeping buffers.
//   batch              batch count.
//   blank              blank class index.
//   stream             CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
CUDA_LIB_EXPORT cudaError_t GenLabelWithBlank(int *label_value,
                                              int *label_value_with_blank,
                                              int *label_squence_length,
                                              int *precum_labels_length,
                                              int *cum_labels_length,
                                              int batch,
                                              int blank,
                                              cudaStream_t stream);
44 
// GenLabelValue - presumably fills the int label buffer `label_value_sp` from the
// (`label_indices`, `label_values`) pair, which presumably describes a sparse label tensor.
// Confirm against the .cu definition.
//   label_value_sp     non-const int* destination.
//   label_indices      read-only (const) int64_t indices.
//   label_values       read-only (const) int values.
//   label_squence_length, cum_labels_length, max_labels_length
//                      non-const per-batch length bookkeeping buffers.
//   size, blank, batch
//                      NOTE(review): the order here is (size, blank, batch), whereas
//                      CalculatePreLength takes (batch, size). All three are plain int, so a
//                      swapped argument compiles silently.
//   stream             CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
CUDA_LIB_EXPORT cudaError_t GenLabelValue(int *label_value_sp,
                                          const int64_t *label_indices,
                                          const int *label_values,
                                          int *label_squence_length,
                                          int *cum_labels_length,
                                          int *max_labels_length,
                                          int size,
                                          int blank,
                                          int batch,
                                          cudaStream_t stream);
48 
// CalculatePreLength - going by its name, derives label-length bookkeeping from
// `label_indices` ahead of the other label-preparation steps. Confirm against the .cu file.
//   label_squence_length, precum_labels_length, cum_labels_length, max_labels_length
//                      non-const int outputs. Presumably these hold the per-batch lengths,
//                      two flavours of running sums, and the maximum length.
//   label_indices      read-only (const) int64_t indices.
//   batch, size        NOTE(review): the order here is (batch, size), whereas GenLabelValue takes
//                      (size, blank, batch). Both are plain int, so a swapped argument compiles
//                      silently.
//   stream             CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
CUDA_LIB_EXPORT cudaError_t CalculatePreLength(int *label_squence_length,
                                               int *precum_labels_length,
                                               int *cum_labels_length,
                                               int *max_labels_length,
                                               const int64_t *label_indices,
                                               int batch,
                                               int size,
                                               cudaStream_t stream);
// CalculateMaxSequence - going by its name, presumably reduces the read-only `sequence_length`
// array over `batch` entries to its maximum.
// NOTE(review): the non-const destination is named `max_labels_length` even though the input is
// a sequence length. Confirm which maximum is actually written.
//   stream             CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
CUDA_LIB_EXPORT cudaError_t CalculateMaxSequence(const int *sequence_length,
                                                 int *max_labels_length,
                                                 int batch,
                                                 cudaStream_t stream);
// CTCLoss - final CTC step over element type T. Going by the parameter names, it combines the
// forward (`log_alpha_b`) and backward (`log_beta_b`) variables with `softmax_probs` to fill
// `cost` and `grads`.
// NOTE(review): inferred from names and const-ness only; confirm against the .cu definition.
//   log_alpha_b, log_beta_b, softmax_probs
//                      non-const T*. Presumably the results of CalculateFwdVar, CalculateBwdVar
//                      and InnerSoftMax respectively.
//   label_value_with_blank
//                      non-const int*; presumably the blank-extended labels.
//   batch, SOffSet, maxtime, numclass
//                      batch count, presumably the extended-label stride, presumably the
//                      maximum time-step count, and the class count.
//   sequence_length    read-only (const) per-batch input lengths.
//   label_squence_length, cum_labels_length
//                      presumably per-batch label lengths and their running sums.
//   cost, grads, prob_num
//                      non-const T*. Presumably the per-batch loss, the gradient buffer, and a
//                      per-batch probability buffer.
//   ignore_longer_outputs_than_inputs
//                      presumably mirrors the identically named TensorFlow ctc_loss option.
//   stream             CUDA stream the work is presumably issued on.
// Returns the cudaError_t status produced by the implementation.
template <typename T>
CUDA_LIB_EXPORT cudaError_t CTCLoss(T *log_alpha_b,
                                    T *log_beta_b,
                                    T *softmax_probs,
                                    int *label_value_with_blank,
                                    int batch,
                                    int SOffSet,
                                    int maxtime,
                                    int numclass,
                                    const int *sequence_length,
                                    int *label_squence_length,
                                    int *cum_labels_length,
                                    T *cost,
                                    T *grads,
                                    T *prob_num,
                                    bool ignore_longer_outputs_than_inputs,
                                    cudaStream_t stream);
59 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_CTCLOSS_IMPL_CUH_
60