1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "plugin/device/cpu/kernel/ctcloss_cpu_kernel.h"
#include <algorithm>
#include <limits>
#include <map>
#include "plugin/device/cpu/hal/device/cpu_device_address.h"
20
21 namespace mindspore {
22 namespace kernel {
23 namespace {
// Inputs: probs, labels_indices, labels_values, sequence_length (see LaunchKernel).
constexpr size_t kCTCLossInputsNum = 4;
// Outputs: per-batch loss and the gradient w.r.t. probs.
constexpr size_t kCTCLossOutputsNum = 2;
26
// Returns log(exp(logprob1) + exp(logprob2)) without leaving log space.
// A -inf operand (log of probability zero) is the identity, so the other operand is returned unchanged.
// Otherwise the larger operand is factored out, so exp() only sees a non-positive argument.
template <typename T>
inline T LogSumExp(const T logprob1, const T logprob2) {
  const T neg_inf = -std::numeric_limits<T>::infinity();
  if (logprob1 <= neg_inf) {
    return logprob2;
  }
  if (logprob2 <= neg_inf) {
    return logprob1;
  }
  const bool first_is_larger = logprob1 > logprob2;
  const T larger = first_is_larger ? logprob1 : logprob2;
  const T smaller = first_is_larger ? logprob2 : logprob1;
  return larger + static_cast<T>(log1p(exp(smaller - larger)));
}
39
40 template <typename T>
InnerSoftMax(const T * inputs_addr,std::vector<std::vector<T>> * softmax_probs,const uint32_t sequence_length,size_t num_class,size_t batch_size,size_t b)41 void InnerSoftMax(const T *inputs_addr, std::vector<std::vector<T>> *softmax_probs, const uint32_t sequence_length,
42 size_t num_class, size_t batch_size, size_t b) {
43 for (size_t t = 0; t < sequence_length; ++t) {
44 auto maxCoeff = static_cast<T>(0);
45 auto sumCoeff = static_cast<T>(0);
46
47 for (size_t c = 0; c < num_class; ++c) {
48 if (inputs_addr[t * batch_size * num_class + b * num_class + c] > maxCoeff) {
49 maxCoeff = inputs_addr[t * batch_size * num_class + b * num_class + c];
50 }
51 }
52
53 for (size_t c = 0; c < num_class; ++c) {
54 sumCoeff += static_cast<T>(exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff));
55 (*softmax_probs)[c][t] =
56 static_cast<T>(exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff));
57 }
58
59 for (size_t c = 0; c < num_class; ++c) {
60 (*softmax_probs)[c][t] /= sumCoeff;
61 }
62 }
63 }
64
65 template <typename T>
MatrixFromVector(uint32_t row,uint32_t col,std::vector<std::vector<T>> * array2D,const T init_value)66 void MatrixFromVector(uint32_t row, uint32_t col, std::vector<std::vector<T>> *array2D, const T init_value) {
67 array2D->resize(row);
68 for (size_t i = 0; i < row; ++i) {
69 (*array2D)[i].resize(col, init_value);
70 }
71 }
72 } // namespace
73
Init(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)74 bool CTCLossCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
75 CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCTCLossInputsNum, kernel_name_);
76 CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCTCLossOutputsNum, kernel_name_);
77
78 preprocess_collapse_repeated_ = GetValue<bool>(primitive_->GetAttr(PCR));
79 ctc_merge_repeated_ = GetValue<bool>(primitive_->GetAttr(CTR));
80 ignore_longer_outputs_than_inputs_ = GetValue<bool>(primitive_->GetAttr(ILOTI));
81 return true;
82 }
83
Resize(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)84 int CTCLossCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
85 if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) {
86 return ret;
87 }
88 probs_shape_ = inputs[0]->GetShapeVector();
89 indices_dims_ = inputs[1]->GetShapeVector();
90 labels_dims_ = inputs[2]->GetShapeVector();
91 dtype_ = inputs[0]->dtype_id();
92
93 if (probs_shape_.size() != 3) {
94 MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'probs' must be 3-D, but got " << probs_shape_.size()
95 << "-D.";
96 }
97 if (labels_dims_.size() != 1) {
98 MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'labels' must be 1-D, but got " << labels_dims_.size()
99 << "-D.";
100 }
101 if (indices_dims_.size() != 2) {
102 MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'labels_indices' must be 2-D, but got "
103 << indices_dims_.size() << "-D.";
104 }
105
106 max_time_ = LongToSize(probs_shape_[0]);
107 batch_size_ = LongToSize(probs_shape_[1]);
108 num_class_ = LongToSize(probs_shape_[2]);
109 blank_index_ = num_class_ - 1;
110 return KRET_OK;
111 }
112
Launch(const std::vector<kernel::KernelTensor * > & inputs,const std::vector<kernel::KernelTensor * > &,const std::vector<kernel::KernelTensor * > & outputs)113 bool CTCLossCpuKernelMod::Launch(const std::vector<kernel::KernelTensor *> &inputs,
114 const std::vector<kernel::KernelTensor *> &,
115 const std::vector<kernel::KernelTensor *> &outputs) {
116 CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCTCLossInputsNum, kernel_name_);
117 CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCTCLossOutputsNum, kernel_name_);
118 if (dtype_ == kNumberTypeFloat16) {
119 LaunchKernel<float16>(inputs, outputs);
120 } else if (dtype_ == kNumberTypeFloat32) {
121 LaunchKernel<float>(inputs, outputs);
122 } else {
123 MS_LOG(EXCEPTION) << "For '" << kernel_name_
124 << "', the dtype of input 'x' must be float16 or float32 on CPU, but got "
125 << TypeIdToType(dtype_)->ToString();
126 }
127 return true;
128 }
129
130 template <typename TT>
CalculateFwdVar(const std::vector<uint32_t> & label_with_blank,const std::vector<std::vector<TT>> & y,std::vector<std::vector<TT>> * log_alpha_b) const131 void CTCLossCpuKernelMod::CalculateFwdVar(const std::vector<uint32_t> &label_with_blank,
132 const std::vector<std::vector<TT>> &y,
133 std::vector<std::vector<TT>> *log_alpha_b) const {
134 int U = label_with_blank.size();
135 int T = (*log_alpha_b)[0].size();
136 TT kLogZero_ = -std::numeric_limits<TT>::infinity();
137
138 (*log_alpha_b)[0][0] = static_cast<TT>(log(y[blank_index_][0]));
139 auto label_0 = (label_with_blank.size() > 1) ? label_with_blank[1] : blank_index_;
140 if (label_with_blank.size() > 1) {
141 (*log_alpha_b)[1][0] = static_cast<TT>(log(y[label_0][0]));
142 }
143
144 for (int t = 1; t < T; ++t) {
145 int low = std::max(0, U - (2 * (T - t)));
146 int high = std::min(U, 2 * (t + 1));
147 for (int u = low; u < high; ++u) {
148 auto sum_log_alpha_b = kLogZero_;
149 if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
150 sum_log_alpha_b = (*log_alpha_b)[u][t - 1];
151 }
152
153 if (u > 0) {
154 sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 1][t - 1]);
155 }
156
157 if (u > 1) {
158 bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u - 2]);
159 if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
160 sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 2][t - 1]);
161 }
162 }
163
164 (*log_alpha_b)[u][t] =
165 static_cast<TT>(log(static_cast<TT>(y[label_with_blank[IntToSize(u)]][IntToSize(t)]))) + sum_log_alpha_b;
166 }
167 }
168 }
169
170 template <typename TT>
CalculateBwdVar(const std::vector<uint32_t> & label_with_blank,const std::vector<std::vector<TT>> & y,std::vector<std::vector<TT>> * log_beta_b) const171 void CTCLossCpuKernelMod::CalculateBwdVar(const std::vector<uint32_t> &label_with_blank,
172 const std::vector<std::vector<TT>> &y,
173 std::vector<std::vector<TT>> *log_beta_b) const {
174 int T = (*log_beta_b)[0].size();
175 int U = label_with_blank.size();
176 if (U > 1) {
177 for (int u = U - 2; u < U; ++u) {
178 (*log_beta_b)[u][T - 1] = TT(0);
179 }
180 } else {
181 (*log_beta_b)[0][T - 1] = TT(0);
182 (*log_beta_b)[0][T - 2] = TT(0);
183 }
184
185 for (int t = T - 2; t >= 0; --t) {
186 int low = std::max(0, U - (2 * (T - t)));
187 int high = std::min(U, 2 * (t + 1));
188 for (int u = low; u < high; ++u) {
189 if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
190 (*log_beta_b)[u][t] =
191 LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u][t + 1] + TT(log(y[label_with_blank[u]][t + 1])));
192 }
193
194 if (u + 1 < U) {
195 (*log_beta_b)[u][t] =
196 LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 1][t + 1] + TT(log(y[label_with_blank[u + 1]][t + 1])));
197 }
198
199 if (u + 2 < U) {
200 bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u + 2]);
201 if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
202 (*log_beta_b)[u][t] =
203 LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 2][t + 1] + TT(log(y[label_with_blank[u + 2]][t + 1])));
204 }
205 }
206 }
207 }
208 }
209
210 template <typename TT>
CalculateGrad(const std::vector<uint32_t> & label_with_blank,const std::vector<std::vector<TT>> & y,const std::vector<std::vector<TT>> & log_alpha_b,const std::vector<std::vector<TT>> & log_beta_b,const TT log_pzx,std::vector<std::vector<TT>> * dy) const211 void CTCLossCpuKernelMod::CalculateGrad(const std::vector<uint32_t> &label_with_blank,
212 const std::vector<std::vector<TT>> &y,
213 const std::vector<std::vector<TT>> &log_alpha_b,
214 const std::vector<std::vector<TT>> &log_beta_b, const TT log_pzx,
215 std::vector<std::vector<TT>> *dy) const {
216 auto dy_b = dy;
217 TT kLogZero_ = -std::numeric_limits<TT>::infinity();
218 if (log_pzx <= kLogZero_) {
219 MS_LOG(INFO) << "No valid path found";
220 return;
221 }
222
223 size_t L = y.size();
224 size_t T = y[0].size();
225 size_t U = label_with_blank.size();
226
227 for (size_t t = 0; t < T; ++t) {
228 std::vector<TT> prob_sum(L, kLogZero_);
229
230 for (size_t u = 0; u < U; ++u) {
231 uint32_t l = label_with_blank[u];
232 prob_sum[l] = LogSumExp(prob_sum[l], log_alpha_b[u][t] + log_beta_b[u][t]);
233 }
234 for (size_t l = 0; l < L; ++l) {
235 (*dy_b)[l][t] = y[l][t] - static_cast<TT>(exp(prob_sum[l] - log_pzx));
236 }
237 }
238 }
239
GenLabelWithBlank(const uint32_t * seq_len,const std::vector<std::vector<uint32_t>> & batch_label,std::vector<std::vector<uint32_t>> * label_with_blank) const240 void CTCLossCpuKernelMod::GenLabelWithBlank(const uint32_t *seq_len,
241 const std::vector<std::vector<uint32_t>> &batch_label,
242 std::vector<std::vector<uint32_t>> *label_with_blank) const {
243 for (size_t b = 0; b < batch_size_; ++b) {
244 std::vector<uint32_t> l;
245 const std::vector<uint32_t> &label = batch_label[b];
246 bool has_blank = false;
247 for (size_t i = 0; i < label.size(); ++i) {
248 if (i == 0 || !preprocess_collapse_repeated_ || label[i] != label[i - 1]) {
249 if (label[i] >= num_class_ - 1) {
250 has_blank = true;
251 } else {
252 if (has_blank) {
253 MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of labels_values[" << i
254 << "] must be in the range of [0, num_classes), but got " << label[i];
255 }
256 l.push_back(label[i]);
257 }
258 }
259 }
260 if (!ignore_longer_outputs_than_inputs_ && l.size() > seq_len[b]) {
261 MS_LOG(EXCEPTION) << "For '" << kernel_name_
262 << ", input time(sequence length) must be greater than "
263 "output size(label length), but got sequence length: "
264 << seq_len[b] << " and label length: " << l.size();
265 }
266
267 (*label_with_blank)[b].reserve(2 * l.size() + 1);
268 for (auto l_i : l) {
269 (*label_with_blank)[b].push_back(blank_index_);
270 (*label_with_blank)[b].push_back(l_i);
271 }
272 (*label_with_blank)[b].push_back(blank_index_);
273 }
274 }
275
276 template <typename T>
LaunchKernel(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs) const277 void CTCLossCpuKernelMod::LaunchKernel(const std::vector<KernelTensor *> &inputs,
278 const std::vector<KernelTensor *> &outputs) const {
279 const auto *inputs_addr = reinterpret_cast<T *>(inputs[0]->device_ptr());
280 const auto *labels_indices_addr = reinterpret_cast<uint64_t *>(inputs[1]->device_ptr());
281 const auto *labels_values_addr = reinterpret_cast<uint32_t *>(inputs[2]->device_ptr());
282 const auto *sequence_length_addr = reinterpret_cast<uint32_t *>(inputs[3]->device_ptr());
283 auto *loss_addr = reinterpret_cast<T *>(outputs[0]->device_ptr());
284 auto *gradient_addr = reinterpret_cast<T *>(outputs[1]->device_ptr());
285
286 std::vector<std::vector<uint32_t>> label_batch;
287 std::vector<std::vector<uint32_t>> labels_with_blank;
288 std::vector<uint64_t> each_label_length;
289
290 label_batch.resize(batch_size_);
291 labels_with_blank.resize(batch_size_);
292 each_label_length.resize(batch_size_, 0);
293
294 T kLogZero_ = -std::numeric_limits<T>::infinity();
295 // check validation of sequence length
296 for (size_t b = 0; b < batch_size_; ++b) {
297 if (sequence_length_addr[b] == static_cast<uint32_t>(0)) {
298 MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", the 'sequence_length' must be greater than 0, but got "
299 << sequence_length_addr[b] << ".";
300 }
301 if (sequence_length_addr[b] > max_time_) {
302 MS_LOG(EXCEPTION) << "For '" << kernel_name_
303 << ", the 'max_time'(the 1st dimension value of 'probs') must be "
304 "greater than or equal to 'sequence_length', but got 'max_time': "
305 << max_time_ << " and 'sequence_length': " << sequence_length_addr[b];
306 }
307 }
308 for (size_t i = 0; i < LongToSize(indices_dims_[0]); ++i) {
309 const size_t factor = 2;
310 auto index = labels_indices_addr[i * factor];
311 if (index >= SizeToUlong(each_label_length.size())) {
312 MS_LOG(EXCEPTION) << "For '" << kernel_name_
313 << ", 'index' must be less than the length of 'label', but got 'index': " << index
314 << " and the length of 'label': " << SizeToUlong(each_label_length.size());
315 }
316 each_label_length[index]++;
317 }
318
319 // convert label format of label_value and label_indices to batch_label
320 uint64_t cum_sum = 0;
321 for (size_t b = 0; b < batch_size_; ++b) {
322 std::vector<uint32_t> *b_value = &label_batch[b];
323 for (size_t l = 0; l < each_label_length[b]; ++l) {
324 b_value->push_back(labels_values_addr[cum_sum + l]);
325 }
326 cum_sum += each_label_length[b];
327 }
328
329 // convert label to label with blank
330 GenLabelWithBlank(sequence_length_addr, label_batch, &labels_with_blank);
331
332 for (size_t b = 0; b < batch_size_; ++b) {
333 std::vector<uint32_t> label_with_blank = labels_with_blank[b];
334 // y_b [num_class, sequence_length]
335 std::vector<std::vector<T>> y_b;
336 std::vector<std::vector<T>> dy;
337 std::vector<std::vector<T>> log_alpha_b;
338 std::vector<std::vector<T>> log_beta_b;
339 MatrixFromVector(num_class_, sequence_length_addr[b], &y_b, kLogZero_);
340 MatrixFromVector(y_b.size(), y_b[0].size(), &dy, T(0));
341 MatrixFromVector(label_with_blank.size(), sequence_length_addr[b], &log_alpha_b, kLogZero_);
342 MatrixFromVector(label_with_blank.size(), sequence_length_addr[b], &log_beta_b, kLogZero_);
343 InnerSoftMax(inputs_addr, &y_b, sequence_length_addr[b], num_class_, batch_size_, b);
344 CalculateFwdVar(label_with_blank, y_b, &log_alpha_b);
345 CalculateBwdVar(label_with_blank, y_b, &log_beta_b);
346
347 T log_pzx = kLogZero_;
348 for (size_t u = 0; u < label_with_blank.size(); ++u) {
349 log_pzx = LogSumExp(log_pzx, log_alpha_b[u][0] + log_beta_b[u][0]);
350 }
351 loss_addr[b] = -log_pzx;
352 CalculateGrad(label_with_blank, y_b, log_alpha_b, log_beta_b, log_pzx, &dy);
353
354 for (size_t t = 0; t < sequence_length_addr[b]; ++t) {
355 for (size_t c = 0; c < num_class_; ++c) {
356 gradient_addr[t * batch_size_ * num_class_ + b * num_class_ + c] = dy[c][t];
357 }
358 }
359 }
360 }
361
GetOpSupport()362 std::vector<KernelAttr> CTCLossCpuKernelMod::GetOpSupport() {
363 static std::vector<KernelAttr> support_list = {KernelAttr()
364 .AddInputAttr(kNumberTypeFloat16)
365 .AddInputAttr(kNumberTypeInt64)
366 .AddInputAttr(kNumberTypeInt32)
367 .AddInputAttr(kNumberTypeInt32)
368 .AddOutputAttr(kNumberTypeFloat16)
369 .AddOutputAttr(kNumberTypeFloat16),
370 KernelAttr()
371 .AddInputAttr(kNumberTypeFloat32)
372 .AddInputAttr(kNumberTypeInt64)
373 .AddInputAttr(kNumberTypeInt32)
374 .AddInputAttr(kNumberTypeInt32)
375 .AddOutputAttr(kNumberTypeFloat32)
376 .AddOutputAttr(kNumberTypeFloat32)};
377
378 return support_list;
379 }
380
// Register CTCLossCpuKernelMod with the NativeCpuKernelMod factory under the operator name CTCLoss.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CTCLoss, CTCLossCpuKernelMod);
382 } // namespace kernel
383 } // namespace mindspore
384