• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/cpu/kernel/ctcloss_cpu_kernel.h"
18 #include <map>
19 #include "plugin/device/cpu/hal/device/cpu_device_address.h"
20 
21 namespace mindspore {
22 namespace kernel {
23 namespace {
24 constexpr size_t kCTCLossInputsNum = 4;
25 constexpr size_t kCTCLossOutputsNum = 2;
26 
template <typename T>
// Numerically stable log(exp(a) + exp(b)): factor out the larger term so the
// argument of exp() is never positive. Treats -inf as log(0), i.e. the
// additive identity of log-space accumulation.
inline T LogSumExp(const T logprob1, const T logprob2) {
  const T neg_inf = -std::numeric_limits<T>::infinity();
  if (logprob1 <= neg_inf) {
    return logprob2;
  }
  if (logprob2 <= neg_inf) {
    return logprob1;
  }
  const bool first_is_larger = logprob1 > logprob2;
  const T hi = first_is_larger ? logprob1 : logprob2;
  const T lo = first_is_larger ? logprob2 : logprob1;
  return hi + static_cast<T>(log1p(exp(lo - hi)));
}
39 
template <typename T>
// Computes a softmax over the class dimension for one batch element, one time
// step at a time.
// inputs_addr: flat logits laid out as [max_time, batch_size, num_class].
// softmax_probs: pre-sized [num_class][sequence_length] output probabilities.
// sequence_length: number of valid time steps for batch element b.
// b: index of the batch element to process.
void InnerSoftMax(const T *inputs_addr, std::vector<std::vector<T>> *softmax_probs, const uint32_t sequence_length,
                  size_t num_class, size_t batch_size, size_t b) {
  for (size_t t = 0; t < sequence_length; ++t) {
    // Base offset of the (t, b) logit row in the flat buffer; hoisted so it is
    // not recomputed for every class in every pass.
    const size_t base = t * batch_size * num_class + b * num_class;

    // NOTE(review): maxCoeff starts at 0 rather than -inf, so all-negative
    // logits are shifted by 0 instead of their true max. That is only a
    // numerical-stability quirk (softmax is shift-invariant); kept as-is for
    // bit-exact compatibility with the previous implementation.
    auto maxCoeff = static_cast<T>(0);
    for (size_t c = 0; c < num_class; ++c) {
      if (inputs_addr[base + c] > maxCoeff) {
        maxCoeff = inputs_addr[base + c];
      }
    }

    // One exp() per element (the original evaluated it twice): store the value
    // and accumulate the normalizer in the same pass.
    auto sumCoeff = static_cast<T>(0);
    for (size_t c = 0; c < num_class; ++c) {
      const T e = static_cast<T>(exp(inputs_addr[base + c] - maxCoeff));
      sumCoeff += e;
      (*softmax_probs)[c][t] = e;
    }

    for (size_t c = 0; c < num_class; ++c) {
      (*softmax_probs)[c][t] /= sumCoeff;
    }
  }
}
64 
template <typename T>
// Shapes *array2D into a row x col matrix; newly created cells are filled with
// init_value (cells that already existed keep their contents, per
// vector::resize semantics).
void MatrixFromVector(uint32_t row, uint32_t col, std::vector<std::vector<T>> *array2D, const T init_value) {
  array2D->resize(row);
  for (auto &line : *array2D) {
    line.resize(col, init_value);
  }
}
72 }  // namespace
73 
bool CTCLossCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
  // One-time setup: validate the fixed arity (4 inputs: probs, labels_indices,
  // labels_values, sequence_length; 2 outputs: loss, gradient) and cache the
  // primitive's boolean attributes in member flags.
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCTCLossInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCTCLossOutputsNum, kernel_name_);

  // PCR/CTR/ILOTI are attribute-name constants declared elsewhere; presumably
  // "preprocess_collapse_repeated", "ctc_merge_repeated" and
  // "ignore_longer_outputs_than_inputs" -- confirm against the op definition.
  preprocess_collapse_repeated_ = GetValue<bool>(primitive_->GetAttr(PCR));
  ctc_merge_repeated_ = GetValue<bool>(primitive_->GetAttr(CTR));
  ignore_longer_outputs_than_inputs_ = GetValue<bool>(primitive_->GetAttr(ILOTI));
  return true;
}
83 
Resize(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)84 int CTCLossCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
85   if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) {
86     return ret;
87   }
88   probs_shape_ = inputs[0]->GetShapeVector();
89   indices_dims_ = inputs[1]->GetShapeVector();
90   labels_dims_ = inputs[2]->GetShapeVector();
91   dtype_ = inputs[0]->dtype_id();
92 
93   if (probs_shape_.size() != 3) {
94     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'probs' must be 3-D, but got " << probs_shape_.size()
95                       << "-D.";
96   }
97   if (labels_dims_.size() != 1) {
98     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'labels' must be 1-D, but got " << labels_dims_.size()
99                       << "-D.";
100   }
101   if (indices_dims_.size() != 2) {
102     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'labels_indices' must be 2-D, but got "
103                       << indices_dims_.size() << "-D.";
104   }
105 
106   max_time_ = LongToSize(probs_shape_[0]);
107   batch_size_ = LongToSize(probs_shape_[1]);
108   num_class_ = LongToSize(probs_shape_[2]);
109   blank_index_ = num_class_ - 1;
110   return KRET_OK;
111 }
112 
Launch(const std::vector<kernel::KernelTensor * > & inputs,const std::vector<kernel::KernelTensor * > &,const std::vector<kernel::KernelTensor * > & outputs)113 bool CTCLossCpuKernelMod::Launch(const std::vector<kernel::KernelTensor *> &inputs,
114                                  const std::vector<kernel::KernelTensor *> &,
115                                  const std::vector<kernel::KernelTensor *> &outputs) {
116   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCTCLossInputsNum, kernel_name_);
117   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCTCLossOutputsNum, kernel_name_);
118   if (dtype_ == kNumberTypeFloat16) {
119     LaunchKernel<float16>(inputs, outputs);
120   } else if (dtype_ == kNumberTypeFloat32) {
121     LaunchKernel<float>(inputs, outputs);
122   } else {
123     MS_LOG(EXCEPTION) << "For '" << kernel_name_
124                       << "', the dtype of input 'x' must be float16 or float32 on CPU, but got "
125                       << TypeIdToType(dtype_)->ToString();
126   }
127   return true;
128 }
129 
template <typename TT>
// Forward (alpha) recursion of the CTC dynamic program, in log space.
// label_with_blank: blank-interleaved target sequence, length U = 2*|labels|+1.
// y: per-class softmax probabilities for this batch element, [num_class][T].
// log_alpha_b: [U][T] table pre-filled with log(0) = -inf; filled in place.
void CTCLossCpuKernelMod::CalculateFwdVar(const std::vector<uint32_t> &label_with_blank,
                                          const std::vector<std::vector<TT>> &y,
                                          std::vector<std::vector<TT>> *log_alpha_b) const {
  int U = label_with_blank.size();
  int T = (*log_alpha_b)[0].size();
  TT kLogZero_ = -std::numeric_limits<TT>::infinity();

  // t = 0: a valid alignment may start in the leading blank or, if the target
  // is non-empty, in its first real label.
  (*log_alpha_b)[0][0] = static_cast<TT>(log(y[blank_index_][0]));
  auto label_0 = (label_with_blank.size() > 1) ? label_with_blank[1] : blank_index_;
  if (label_with_blank.size() > 1) {
    (*log_alpha_b)[1][0] = static_cast<TT>(log(y[label_0][0]));
  }

  for (int t = 1; t < T; ++t) {
    // Band pruning: only states reachable from the start (u < 2*(t+1)) that can
    // still reach the end (u >= U - 2*(T-t)) need to be computed.
    int low = std::max(0, U - (2 * (T - t)));
    int high = std::min(U, 2 * (t + 1));
    for (int u = low; u < high; ++u) {
      auto sum_log_alpha_b = kLogZero_;
      // Self-transition: blanks may always repeat; real labels may repeat only
      // when repeated emissions are not merged.
      if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
        sum_log_alpha_b = (*log_alpha_b)[u][t - 1];
      }

      // Transition from the immediately preceding state.
      if (u > 0) {
        sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 1][t - 1]);
      }

      // Skip transition over a blank, allowed only between two distinct labels
      // (or when merging is disabled).
      if (u > 1) {
        bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u - 2]);
        if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
          sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 2][t - 1]);
        }
      }

      // alpha[u][t] = P(emit label u at time t) * sum of incoming path masses.
      (*log_alpha_b)[u][t] =
        static_cast<TT>(log(static_cast<TT>(y[label_with_blank[IntToSize(u)]][IntToSize(t)]))) + sum_log_alpha_b;
    }
  }
}
169 
170 template <typename TT>
CalculateBwdVar(const std::vector<uint32_t> & label_with_blank,const std::vector<std::vector<TT>> & y,std::vector<std::vector<TT>> * log_beta_b) const171 void CTCLossCpuKernelMod::CalculateBwdVar(const std::vector<uint32_t> &label_with_blank,
172                                           const std::vector<std::vector<TT>> &y,
173                                           std::vector<std::vector<TT>> *log_beta_b) const {
174   int T = (*log_beta_b)[0].size();
175   int U = label_with_blank.size();
176   if (U > 1) {
177     for (int u = U - 2; u < U; ++u) {
178       (*log_beta_b)[u][T - 1] = TT(0);
179     }
180   } else {
181     (*log_beta_b)[0][T - 1] = TT(0);
182     (*log_beta_b)[0][T - 2] = TT(0);
183   }
184 
185   for (int t = T - 2; t >= 0; --t) {
186     int low = std::max(0, U - (2 * (T - t)));
187     int high = std::min(U, 2 * (t + 1));
188     for (int u = low; u < high; ++u) {
189       if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
190         (*log_beta_b)[u][t] =
191           LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u][t + 1] + TT(log(y[label_with_blank[u]][t + 1])));
192       }
193 
194       if (u + 1 < U) {
195         (*log_beta_b)[u][t] =
196           LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 1][t + 1] + TT(log(y[label_with_blank[u + 1]][t + 1])));
197       }
198 
199       if (u + 2 < U) {
200         bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u + 2]);
201         if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
202           (*log_beta_b)[u][t] =
203             LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 2][t + 1] + TT(log(y[label_with_blank[u + 2]][t + 1])));
204         }
205       }
206     }
207   }
208 }
209 
210 template <typename TT>
CalculateGrad(const std::vector<uint32_t> & label_with_blank,const std::vector<std::vector<TT>> & y,const std::vector<std::vector<TT>> & log_alpha_b,const std::vector<std::vector<TT>> & log_beta_b,const TT log_pzx,std::vector<std::vector<TT>> * dy) const211 void CTCLossCpuKernelMod::CalculateGrad(const std::vector<uint32_t> &label_with_blank,
212                                         const std::vector<std::vector<TT>> &y,
213                                         const std::vector<std::vector<TT>> &log_alpha_b,
214                                         const std::vector<std::vector<TT>> &log_beta_b, const TT log_pzx,
215                                         std::vector<std::vector<TT>> *dy) const {
216   auto dy_b = dy;
217   TT kLogZero_ = -std::numeric_limits<TT>::infinity();
218   if (log_pzx <= kLogZero_) {
219     MS_LOG(INFO) << "No valid path found";
220     return;
221   }
222 
223   size_t L = y.size();
224   size_t T = y[0].size();
225   size_t U = label_with_blank.size();
226 
227   for (size_t t = 0; t < T; ++t) {
228     std::vector<TT> prob_sum(L, kLogZero_);
229 
230     for (size_t u = 0; u < U; ++u) {
231       uint32_t l = label_with_blank[u];
232       prob_sum[l] = LogSumExp(prob_sum[l], log_alpha_b[u][t] + log_beta_b[u][t]);
233     }
234     for (size_t l = 0; l < L; ++l) {
235       (*dy_b)[l][t] = y[l][t] - static_cast<TT>(exp(prob_sum[l] - log_pzx));
236     }
237   }
238 }
239 
// For every batch element, validates the raw label sequence and rewrites it in
// the CTC blank-interleaved form: blank, l1, blank, l2, ..., blank
// (length 2*|l|+1), stored into (*label_with_blank)[b].
void CTCLossCpuKernelMod::GenLabelWithBlank(const uint32_t *seq_len,
                                            const std::vector<std::vector<uint32_t>> &batch_label,
                                            std::vector<std::vector<uint32_t>> *label_with_blank) const {
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> l;
    const std::vector<uint32_t> &label = batch_label[b];
    bool has_blank = false;
    for (size_t i = 0; i < label.size(); ++i) {
      // Optionally collapse adjacent repeated labels before inserting blanks.
      if (i == 0 || !preprocess_collapse_repeated_ || label[i] != label[i - 1]) {
        // Values >= num_class_ - 1 are treated as blanks; once a blank has been
        // seen, any later real label is rejected as out of range (i.e. blanks
        // may only appear at the tail of the value list).
        if (label[i] >= num_class_ - 1) {
          has_blank = true;
        } else {
          if (has_blank) {
            MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the value of labels_values[" << i
                              << "] must be in the range of [0, num_classes), but got " << label[i];
          }
          l.push_back(label[i]);
        }
      }
    }
    // A label longer than the input sequence cannot be aligned; error out
    // unless the op was configured to ignore that case.
    if (!ignore_longer_outputs_than_inputs_ && l.size() > seq_len[b]) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << ", input time(sequence length) must be greater than "
                           "output size(label length), but got sequence length: "
                        << seq_len[b] << " and label length: " << l.size();
    }

    // Interleave: blank before every label, plus one trailing blank.
    (*label_with_blank)[b].reserve(2 * l.size() + 1);
    for (auto l_i : l) {
      (*label_with_blank)[b].push_back(blank_index_);
      (*label_with_blank)[b].push_back(l_i);
    }
    (*label_with_blank)[b].push_back(blank_index_);
  }
}
275 
template <typename T>
// Typed implementation: computes loss[b] = -log P(label_b | probs_b) and the
// gradient w.r.t. the softmax of probs, for every batch element b.
// inputs: [0] probs (max_time, batch, num_class), [1] labels_indices (n, 2)
//         int64 sparse indices, [2] labels_values (n) int32,
//         [3] sequence_length (batch) int32.
// outputs: [0] loss (batch), [1] gradient (same layout as probs).
void CTCLossCpuKernelMod::LaunchKernel(const std::vector<KernelTensor *> &inputs,
                                       const std::vector<KernelTensor *> &outputs) const {
  // NOTE(review): the input buffers are only read but are cast to non-const
  // pointers; a const_cast-free const T* cast would express intent better.
  const auto *inputs_addr = reinterpret_cast<T *>(inputs[0]->device_ptr());
  const auto *labels_indices_addr = reinterpret_cast<uint64_t *>(inputs[1]->device_ptr());
  const auto *labels_values_addr = reinterpret_cast<uint32_t *>(inputs[2]->device_ptr());
  const auto *sequence_length_addr = reinterpret_cast<uint32_t *>(inputs[3]->device_ptr());
  auto *loss_addr = reinterpret_cast<T *>(outputs[0]->device_ptr());
  auto *gradient_addr = reinterpret_cast<T *>(outputs[1]->device_ptr());

  // Per-batch label values, their blank-interleaved form, and their lengths.
  std::vector<std::vector<uint32_t>> label_batch;
  std::vector<std::vector<uint32_t>> labels_with_blank;
  std::vector<uint64_t> each_label_length;

  label_batch.resize(batch_size_);
  labels_with_blank.resize(batch_size_);
  each_label_length.resize(batch_size_, 0);

  T kLogZero_ = -std::numeric_limits<T>::infinity();
  // check validation of sequence length
  for (size_t b = 0; b < batch_size_; ++b) {
    if (sequence_length_addr[b] == static_cast<uint32_t>(0)) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", the 'sequence_length' must be greater than 0, but got "
                        << sequence_length_addr[b] << ".";
    }
    if (sequence_length_addr[b] > max_time_) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << ", the 'max_time'(the 1st dimension value of 'probs') must be "
                           "greater than or equal to 'sequence_length', but got 'max_time': "
                        << max_time_ << " and 'sequence_length': " << sequence_length_addr[b];
    }
  }
  // Count how many label values belong to each batch element; labels_indices is
  // an [n, 2] matrix and column 0 (stride 2) holds the batch index.
  for (size_t i = 0; i < LongToSize(indices_dims_[0]); ++i) {
    const size_t factor = 2;
    auto index = labels_indices_addr[i * factor];
    if (index >= SizeToUlong(each_label_length.size())) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << ", 'index' must be less than the length of 'label', but got 'index': " << index
                        << " and the length of 'label': " << SizeToUlong(each_label_length.size());
    }
    each_label_length[index]++;
  }

  // convert label format of label_value and label_indices to batch_label
  // (assumes labels_values is ordered by batch, as produced by the sparse
  // representation upstream -- cum_sum walks it contiguously).
  uint64_t cum_sum = 0;
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> *b_value = &label_batch[b];
    for (size_t l = 0; l < each_label_length[b]; ++l) {
      b_value->push_back(labels_values_addr[cum_sum + l]);
    }
    cum_sum += each_label_length[b];
  }

  // convert label to label with blank
  GenLabelWithBlank(sequence_length_addr, label_batch, &labels_with_blank);

  // Per batch element: softmax, forward/backward DP, loss and gradient.
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> label_with_blank = labels_with_blank[b];
    // y_b [num_class, sequence_length]
    std::vector<std::vector<T>> y_b;
    std::vector<std::vector<T>> dy;
    std::vector<std::vector<T>> log_alpha_b;
    std::vector<std::vector<T>> log_beta_b;
    MatrixFromVector(num_class_, sequence_length_addr[b], &y_b, kLogZero_);
    MatrixFromVector(y_b.size(), y_b[0].size(), &dy, T(0));
    MatrixFromVector(label_with_blank.size(), sequence_length_addr[b], &log_alpha_b, kLogZero_);
    MatrixFromVector(label_with_blank.size(), sequence_length_addr[b], &log_beta_b, kLogZero_);
    InnerSoftMax(inputs_addr, &y_b, sequence_length_addr[b], num_class_, batch_size_, b);
    CalculateFwdVar(label_with_blank, y_b, &log_alpha_b);
    CalculateBwdVar(label_with_blank, y_b, &log_beta_b);

    // Total likelihood: sum over all DP states at t = 0 of alpha * beta.
    T log_pzx = kLogZero_;
    for (size_t u = 0; u < label_with_blank.size(); ++u) {
      log_pzx = LogSumExp(log_pzx, log_alpha_b[u][0] + log_beta_b[u][0]);
    }
    loss_addr[b] = -log_pzx;
    CalculateGrad(label_with_blank, y_b, log_alpha_b, log_beta_b, log_pzx, &dy);

    // Scatter the [num_class][t] gradient back into the flat
    // [max_time, batch, num_class] output layout.
    for (size_t t = 0; t < sequence_length_addr[b]; ++t) {
      for (size_t c = 0; c < num_class_; ++c) {
        gradient_addr[t * batch_size_ * num_class_ + b * num_class_ + c] = dy[c][t];
      }
    }
  }
}
361 
GetOpSupport()362 std::vector<KernelAttr> CTCLossCpuKernelMod::GetOpSupport() {
363   static std::vector<KernelAttr> support_list = {KernelAttr()
364                                                    .AddInputAttr(kNumberTypeFloat16)
365                                                    .AddInputAttr(kNumberTypeInt64)
366                                                    .AddInputAttr(kNumberTypeInt32)
367                                                    .AddInputAttr(kNumberTypeInt32)
368                                                    .AddOutputAttr(kNumberTypeFloat16)
369                                                    .AddOutputAttr(kNumberTypeFloat16),
370                                                  KernelAttr()
371                                                    .AddInputAttr(kNumberTypeFloat32)
372                                                    .AddInputAttr(kNumberTypeInt64)
373                                                    .AddInputAttr(kNumberTypeInt32)
374                                                    .AddInputAttr(kNumberTypeInt32)
375                                                    .AddOutputAttr(kNumberTypeFloat32)
376                                                    .AddOutputAttr(kNumberTypeFloat32)};
377 
378   return support_list;
379 }
380 
381 MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, CTCLoss, CTCLossCpuKernelMod);
382 }  // namespace kernel
383 }  // namespace mindspore
384