/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/ctcloss_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kCTCLossInputsNum = 4;
constexpr size_t kCTCLossOutputsNum = 2;

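// Numerically stable log(exp(logprob1) + exp(logprob2)); treats -inf as log(0).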
template <typename T>
inline T LogSumExp(const T logprob1, const T logprob2) {
  T kLogZero_ = -std::numeric_limits<T>::infinity();
  if (logprob1 <= kLogZero_) {
    return logprob2;
  }
  if (logprob2 <= kLogZero_) {
    return logprob1;
  }
  return (logprob1 > logprob2) ? logprob1 + static_cast<T>(log1p(exp(logprob2 - logprob1)))
                               : logprob2 + static_cast<T>(log1p(exp(logprob1 - logprob2)));
}

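// Computes a numerically stable softmax over the class dimension for batch element b at every valid
// time step and writes the probabilities into (*softmax_probs)[class][time].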
template <typename T>
void InnerSoftMax(const T *inputs_addr, std::vector<std::vector<T>> *softmax_probs, const uint32_t sequence_length,
                  size_t num_class, size_t batch_size, size_t b) {
  for (size_t t = 0; t < sequence_length; ++t) {
    auto maxCoeff = static_cast<T>(0);
    auto sumCoeff = static_cast<T>(0);

    for (size_t c = 0; c < num_class; ++c) {
      if (inputs_addr[t * batch_size * num_class + b * num_class + c] > maxCoeff) {
        maxCoeff = inputs_addr[t * batch_size * num_class + b * num_class + c];
      }
    }

    for (size_t c = 0; c < num_class; ++c) {
      sumCoeff += static_cast<T>(exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff));
      (*softmax_probs)[c][t] =
        static_cast<T>(exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff));
    }

    for (size_t c = 0; c < num_class; ++c) {
      (*softmax_probs)[c][t] /= sumCoeff;
    }
  }
}

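// Allocates a row x col matrix (vector of vectors) filled with init_value.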
template <typename T>
void MatrixFromVector(uint32_t row, uint32_t col, std::vector<std::vector<T>> *array2D, const T init_value) {
  array2D->resize(row);
  for (size_t i = 0; i < row; ++i) {
    (*array2D)[i].resize(col, init_value);
  }
}
}  // namespace

void CTCLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  probs_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  indices_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  labels_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);

  if (probs_shape_.size() != 3) {
    MS_LOG(EXCEPTION) << "Probs dims: " << probs_shape_.size() << " is not supported.";
  }
  if (labels_dims_.size() != 1) {
    MS_LOG(EXCEPTION) << "Labels dims: " << labels_dims_.size() << " is not supported.";
  }
  if (indices_dims_.size() != 2) {
    MS_LOG(EXCEPTION) << "Labels indices dims: " << indices_dims_.size() << " is not supported.";
  }

  preprocess_collapse_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, PCR);
  ctc_merge_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, CTR);
  ignore_longer_outputs_than_inputs_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, ILOTI);
  max_time_ = probs_shape_[0];
  batch_size_ = probs_shape_[1];
  num_class_ = probs_shape_[2];
  blank_index_ = num_class_ - 1;
}

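// Dispatches to the typed implementation according to the dtype of the probs input.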
bool CTCLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                              const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCTCLossInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCTCLossOutputsNum, kernel_name_);
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << kernel_name_ << " only supports float16 and float32 on CPU, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  return true;
}

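// CTC forward recursion: log_alpha_b[u][t] accumulates the log-probability of all partial alignments
// that end at position u of the blank-augmented label at time step t.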
template <typename TT>
void CTCLossCPUKernel::CalculateFwdVar(const std::vector<uint32_t> &label_with_blank,
                                       const std::vector<std::vector<TT>> &y,
                                       std::vector<std::vector<TT>> *log_alpha_b) const {
  int U = label_with_blank.size();
  int T = (*log_alpha_b)[0].size();
  TT kLogZero_ = -std::numeric_limits<TT>::infinity();

  (*log_alpha_b)[0][0] = static_cast<TT>(log(y[blank_index_][0]));
  auto label_0 = (label_with_blank.size() > 1) ? label_with_blank[1] : blank_index_;
  if (label_with_blank.size() > 1) {
    (*log_alpha_b)[1][0] = static_cast<TT>(log(y[label_0][0]));
  }

  for (int t = 1; t < T; ++t) {
    int low = std::max(0, U - (2 * (T - t)));
    int high = std::min(U, 2 * (t + 1));
    for (int u = low; u < high; ++u) {
      auto sum_log_alpha_b = kLogZero_;
      if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
        sum_log_alpha_b = (*log_alpha_b)[u][t - 1];
      }

      if (u > 0) {
        sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 1][t - 1]);
      }

      if (u > 1) {
        bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u - 2]);
        if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
          sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 2][t - 1]);
        }
      }

      (*log_alpha_b)[u][t] =
        static_cast<TT>(log(static_cast<TT>(y[label_with_blank[IntToSize(u)]][IntToSize(t)]))) + sum_log_alpha_b;
    }
  }
}

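// CTC backward recursion: log_beta_b[u][t] accumulates the log-probability of completing the
// blank-augmented label from position u at time step t through the end of the sequence.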
template <typename TT>
void CTCLossCPUKernel::CalculateBwdVar(const std::vector<uint32_t> &label_with_blank,
                                       const std::vector<std::vector<TT>> &y,
                                       std::vector<std::vector<TT>> *log_beta_b) const {
  int T = (*log_beta_b)[0].size();
  int U = label_with_blank.size();
  if (U > 1) {
    for (int u = U - 2; u < U; ++u) {
      (*log_beta_b)[u][T - 1] = TT(0);
    }
  } else {
    (*log_beta_b)[0][T - 1] = TT(0);
    (*log_beta_b)[0][T - 2] = TT(0);
  }

  for (int t = T - 2; t >= 0; --t) {
    int low = std::max(0, U - (2 * (T - t)));
    int high = std::min(U, 2 * (t + 1));
    for (int u = low; u < high; ++u) {
      if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
        (*log_beta_b)[u][t] =
          LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u][t + 1] + TT(log(y[label_with_blank[u]][t + 1])));
      }

      if (u + 1 < U) {
        (*log_beta_b)[u][t] =
          LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 1][t + 1] + TT(log(y[label_with_blank[u + 1]][t + 1])));
      }

      if (u + 2 < U) {
        bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u + 2]);
        if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
          (*log_beta_b)[u][t] =
            LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 2][t + 1] + TT(log(y[label_with_blank[u + 2]][t + 1])));
        }
      }
    }
  }
}

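// Gradient w.r.t. the softmax output: dy[l][t] = y[l][t] - exp(prob_sum[l] - log_pzx), where
// prob_sum[l] log-accumulates alpha[u][t] + beta[u][t] over the label positions u that emit class l.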
template <typename TT>
void CTCLossCPUKernel::CalculateGrad(const std::vector<uint32_t> &label_with_blank,
                                     const std::vector<std::vector<TT>> &y,
                                     const std::vector<std::vector<TT>> &log_alpha_b,
                                     const std::vector<std::vector<TT>> &log_beta_b, const TT log_pzx,
                                     std::vector<std::vector<TT>> *dy) const {
  auto dy_b = dy;
  TT kLogZero_ = -std::numeric_limits<TT>::infinity();
  if (log_pzx <= kLogZero_) {
    MS_LOG(INFO) << "No valid path found";
    return;
  }

  size_t L = y.size();
  size_t T = y[0].size();
  size_t U = label_with_blank.size();

  for (size_t t = 0; t < T; ++t) {
    std::vector<TT> prob_sum(L, kLogZero_);

    for (size_t u = 0; u < U; ++u) {
      uint32_t l = label_with_blank[u];
      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha_b[u][t] + log_beta_b[u][t]);
    }
    for (size_t l = 0; l < L; ++l) {
      (*dy_b)[l][t] = y[l][t] - static_cast<TT>(exp(prob_sum[l] - log_pzx));
    }
  }
}

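// Builds the blank-augmented label (blank, l_1, blank, l_2, ..., blank, l_n, blank) for every batch
// element, optionally collapsing repeated labels when preprocess_collapse_repeated_ is set.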
void CTCLossCPUKernel::GenLabelWithBlank(const uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label,
                                         std::vector<std::vector<uint32_t>> *label_with_blank) const {
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> l;
    const std::vector<uint32_t> &label = batch_label[b];
    bool has_blank = false;
    for (size_t i = 0; i < label.size(); ++i) {
      if (i == 0 || !preprocess_collapse_repeated_ || label[i] != label[i - 1]) {
        if (label[i] >= num_class_ - 1) {
          has_blank = true;
        } else {
          if (has_blank) {
            MS_LOG(EXCEPTION) << "Invalid labels (index >= num_class - 1) should not appear between two valid labels";
          }
          l.push_back(label[i]);
        }
      }
    }
    if (!ignore_longer_outputs_than_inputs_ && l.size() > seq_len[b]) {
      MS_LOG(EXCEPTION) << "Input time (sequence length) should be no less than output size (label length), but got "
                        << seq_len[b] << " < " << l.size();
    }

    (*label_with_blank)[b].reserve(2 * l.size() + 1);
    for (auto l_i : l) {
      (*label_with_blank)[b].push_back(blank_index_);
      (*label_with_blank)[b].push_back(l_i);
    }
    (*label_with_blank)[b].push_back(blank_index_);
  }
}

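// Per-batch CTC loss and gradient computation for one floating-point type.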
template <typename T>
void CTCLossCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                    const std::vector<AddressPtr> &outputs) const {
  const auto *inputs_addr = reinterpret_cast<T *>(inputs[0]->addr);
  const auto *labels_indices_addr = reinterpret_cast<uint64_t *>(inputs[1]->addr);
  const auto *labels_values_addr = reinterpret_cast<uint32_t *>(inputs[2]->addr);
  const auto *sequence_length_addr = reinterpret_cast<uint32_t *>(inputs[3]->addr);
  auto *loss_addr = reinterpret_cast<T *>(outputs[0]->addr);
  auto *gradient_addr = reinterpret_cast<T *>(outputs[1]->addr);

  std::vector<std::vector<uint32_t>> label_batch;
  std::vector<std::vector<uint32_t>> labels_with_blank;
  std::vector<uint64_t> each_label_length;

  label_batch.resize(batch_size_);
  labels_with_blank.resize(batch_size_);
  each_label_length.resize(batch_size_, 0);

  T kLogZero_ = -std::numeric_limits<T>::infinity();
  // check validity of the sequence lengths
  for (size_t b = 0; b < batch_size_; ++b) {
    if (sequence_length_addr[b] == static_cast<uint32_t>(0)) {
      MS_LOG(EXCEPTION) << "Sequence length should be > 0, but got " << sequence_length_addr[b];
    }
    if (sequence_length_addr[b] > max_time_) {
      MS_LOG(EXCEPTION) << "Max time should be no less than sequence length, but got " << max_time_ << " < "
                        << sequence_length_addr[b];
    }
  }
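  // count how many label values belong to each batch element, using the batch index stored in the
  // first column of labels_indices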
  for (size_t i = 0; i < indices_dims_[0]; ++i) {
    const size_t factor = 2;
    auto index = labels_indices_addr[i * factor];
    if (index >= SizeToUlong(each_label_length.size())) {
      MS_LOG(EXCEPTION) << "Index: " << index << " is out of the bounds of the vector.";
    }
    each_label_length[index]++;
  }

  // convert label format of label_value and label_indices to batch_label
  uint64_t cum_sum = 0;
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> *b_value = &label_batch[b];
    for (size_t l = 0; l < each_label_length[b]; ++l) {
      b_value->push_back(labels_values_addr[cum_sum + l]);
    }
    cum_sum += each_label_length[b];
  }

  // convert label to label with blank
  GenLabelWithBlank(sequence_length_addr, label_batch, &labels_with_blank);

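  // for every batch element: softmax the logits, run the forward/backward recursions, accumulate
  // the loss, and write the gradient back in [time, batch, class] layout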
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> label_with_blank = labels_with_blank[b];
    // y_b [num_class, sequence_length]
    std::vector<std::vector<T>> y_b;
    std::vector<std::vector<T>> dy;
    std::vector<std::vector<T>> log_alpha_b;
    std::vector<std::vector<T>> log_beta_b;
    MatrixFromVector(num_class_, sequence_length_addr[b], &y_b, kLogZero_);
    MatrixFromVector(y_b.size(), y_b[0].size(), &dy, T(0));
    MatrixFromVector(label_with_blank.size(), sequence_length_addr[b], &log_alpha_b, kLogZero_);
    MatrixFromVector(label_with_blank.size(), sequence_length_addr[b], &log_beta_b, kLogZero_);
    InnerSoftMax(inputs_addr, &y_b, sequence_length_addr[b], num_class_, batch_size_, b);
    CalculateFwdVar(label_with_blank, y_b, &log_alpha_b);
    CalculateBwdVar(label_with_blank, y_b, &log_beta_b);

    T log_pzx = kLogZero_;
    for (size_t u = 0; u < label_with_blank.size(); ++u) {
      log_pzx = LogSumExp(log_pzx, log_alpha_b[u][0] + log_beta_b[u][0]);
    }
    loss_addr[b] = -log_pzx;
    CalculateGrad(label_with_blank, y_b, log_alpha_b, log_beta_b, log_pzx, &dy);

    for (size_t t = 0; t < sequence_length_addr[b]; ++t) {
      for (size_t c = 0; c < num_class_; ++c) {
        gradient_addr[t * batch_size_ * num_class_ + b * num_class_ + c] = dy[c][t];
      }
    }
  }
}
}  // namespace kernel
}  // namespace mindspore