/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/cpu/ctcloss_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kCTCLossInputsNum = 4;
constexpr size_t kCTCLossOutputsNum = 2;

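// Numerically stable log-sum-exp: returns log(exp(logprob1) + exp(logprob2)),
// treating -inf as log(0) so that adding a zero-probability term is a no-op.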
template <typename T>
inline T LogSumExp(const T logprob1, const T logprob2) {
  T kLogZero_ = -std::numeric_limits<T>::infinity();
  if (logprob1 <= kLogZero_) {
    return logprob2;
  }
  if (logprob2 <= kLogZero_) {
    return logprob1;
  }
  return (logprob1 > logprob2) ? logprob1 + static_cast<T>(log1p(exp(logprob2 - logprob1)))
                               : logprob2 + static_cast<T>(log1p(exp(logprob1 - logprob2)));
}

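// Computes a softmax over the class dimension for every valid time step of batch
// element b, writing the probabilities into (*softmax_probs)[class][time].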
template <typename T>
void InnerSoftMax(const T *inputs_addr, std::vector<std::vector<T>> *softmax_probs, const uint32_t sequence_length,
                  size_t num_class, size_t batch_size, size_t b) {
  for (size_t t = 0; t < sequence_length; ++t) {
    auto maxCoeff = static_cast<T>(0);
    auto sumCoeff = static_cast<T>(0);

    for (size_t c = 0; c < num_class; ++c) {
      if (inputs_addr[t * batch_size * num_class + b * num_class + c] > maxCoeff) {
        maxCoeff = inputs_addr[t * batch_size * num_class + b * num_class + c];
      }
    }

    for (size_t c = 0; c < num_class; ++c) {
      sumCoeff += static_cast<T>(exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff));
      (*softmax_probs)[c][t] =
        static_cast<T>(exp(inputs_addr[t * batch_size * num_class + b * num_class + c] - maxCoeff));
    }

    for (size_t c = 0; c < num_class; ++c) {
      (*softmax_probs)[c][t] /= sumCoeff;
    }
  }
}

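// Resizes *array2D to a row x col matrix and fills every element with init_value.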
template <typename T>
void MatrixFromVector(uint32_t row, uint32_t col, std::vector<std::vector<T>> *array2D, const T init_value) {
  array2D->resize(row);
  for (size_t i = 0; i < row; ++i) {
    (*array2D)[i].resize(col, init_value);
  }
}
}  // namespace

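// Reads the input shapes and primitive attributes from the CNode, validates the
// expected ranks (probs: 3-D, labels: 1-D, label indices: 2-D) and caches
// max_time, batch_size, num_class and the blank index (num_class - 1).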
void CTCLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  probs_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  indices_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
  labels_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);

  if (probs_shape_.size() != 3) {
    MS_LOG(EXCEPTION) << "Probs should be a 3-D tensor, but got " << probs_shape_.size() << " dims.";
  }
  if (labels_dims_.size() != 1) {
    MS_LOG(EXCEPTION) << "Labels should be a 1-D tensor, but got " << labels_dims_.size() << " dims.";
  }
  if (indices_dims_.size() != 2) {
    MS_LOG(EXCEPTION) << "Labels indices should be a 2-D tensor, but got " << indices_dims_.size() << " dims.";
  }

  preprocess_collapse_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, PCR);
  ctc_merge_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, CTR);
  ignore_longer_outputs_than_inputs_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, ILOTI);
  max_time_ = probs_shape_[0];
  batch_size_ = probs_shape_[1];
  num_class_ = probs_shape_[2];
  blank_index_ = num_class_ - 1;
}

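// Validates the input/output counts and dispatches to the typed LaunchKernel
// according to the probs data type (float16 or float32).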
bool CTCLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                              const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kCTCLossInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kCTCLossOutputsNum, kernel_name_);
  if (dtype_ == kNumberTypeFloat16) {
    LaunchKernel<float16>(inputs, outputs);
  } else if (dtype_ == kNumberTypeFloat32) {
    LaunchKernel<float>(inputs, outputs);
  } else {
    MS_LOG(EXCEPTION) << kernel_name_ << " only supports float16 and float32 on CPU, but got "
                      << TypeIdToType(dtype_)->ToString();
  }
  return true;
}

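// Computes the CTC forward variables alpha in log space: log_alpha_b[u][t] is the
// log-probability of all alignment prefixes that end at label position u of
// label_with_blank at time step t.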
template <typename TT>
void CTCLossCPUKernel::CalculateFwdVar(const std::vector<uint32_t> &label_with_blank,
                                       const std::vector<std::vector<TT>> &y,
                                       std::vector<std::vector<TT>> *log_alpha_b) const {
  int U = label_with_blank.size();
  int T = (*log_alpha_b)[0].size();
  TT kLogZero_ = -std::numeric_limits<TT>::infinity();

  (*log_alpha_b)[0][0] = static_cast<TT>(log(y[blank_index_][0]));
  auto label_0 = (label_with_blank.size() > 1) ? label_with_blank[1] : blank_index_;
  if (label_with_blank.size() > 1) {
    (*log_alpha_b)[1][0] = static_cast<TT>(log(y[label_0][0]));
  }

  for (int t = 1; t < T; ++t) {
    int low = std::max(0, U - (2 * (T - t)));
    int high = std::min(U, 2 * (t + 1));
    for (int u = low; u < high; ++u) {
      auto sum_log_alpha_b = kLogZero_;
      if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
        sum_log_alpha_b = (*log_alpha_b)[u][t - 1];
      }

      if (u > 0) {
        sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 1][t - 1]);
      }

      if (u > 1) {
        bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u - 2]);
        if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
          sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 2][t - 1]);
        }
      }

      (*log_alpha_b)[u][t] =
        static_cast<TT>(log(static_cast<TT>(y[label_with_blank[IntToSize(u)]][IntToSize(t)]))) + sum_log_alpha_b;
    }
  }
}

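// Computes the CTC backward variables beta in log space: log_beta_b[u][t] accumulates
// the log-probability of completing the alignment from label position u after time
// step t through to the end of the sequence.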
template <typename TT>
void CTCLossCPUKernel::CalculateBwdVar(const std::vector<uint32_t> &label_with_blank,
                                       const std::vector<std::vector<TT>> &y,
                                       std::vector<std::vector<TT>> *log_beta_b) const {
  int T = (*log_beta_b)[0].size();
  int U = label_with_blank.size();
  if (U > 1) {
    for (int u = U - 2; u < U; ++u) {
      (*log_beta_b)[u][T - 1] = TT(0);
    }
  } else {
    (*log_beta_b)[0][T - 1] = TT(0);
    (*log_beta_b)[0][T - 2] = TT(0);
  }

  for (int t = T - 2; t >= 0; --t) {
    int low = std::max(0, U - (2 * (T - t)));
    int high = std::min(U, 2 * (t + 1));
    for (int u = low; u < high; ++u) {
      if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
        (*log_beta_b)[u][t] =
          LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u][t + 1] + TT(log(y[label_with_blank[u]][t + 1])));
      }

      if (u + 1 < U) {
        (*log_beta_b)[u][t] =
          LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 1][t + 1] + TT(log(y[label_with_blank[u + 1]][t + 1])));
      }

      if (u + 2 < U) {
        bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u + 2]);
        if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
          (*log_beta_b)[u][t] =
            LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 2][t + 1] + TT(log(y[label_with_blank[u + 2]][t + 1])));
        }
      }
    }
  }
}

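// Computes the gradient w.r.t. the softmax probabilities for one batch element:
// dy[l][t] = y[l][t] - exp(logsumexp_{u: label_with_blank[u]==l}(alpha[u][t] + beta[u][t]) - log_pzx).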
template <typename TT>
void CTCLossCPUKernel::CalculateGrad(const std::vector<uint32_t> &label_with_blank,
                                     const std::vector<std::vector<TT>> &y,
                                     const std::vector<std::vector<TT>> &log_alpha_b,
                                     const std::vector<std::vector<TT>> &log_beta_b, const TT log_pzx,
                                     std::vector<std::vector<TT>> *dy) const {
  auto dy_b = dy;
  TT kLogZero_ = -std::numeric_limits<TT>::infinity();
  if (log_pzx <= kLogZero_) {
    MS_LOG(INFO) << "No valid path found";
    return;
  }

  size_t L = y.size();
  size_t T = y[0].size();
  size_t U = label_with_blank.size();

  for (size_t t = 0; t < T; ++t) {
    std::vector<TT> prob_sum(L, kLogZero_);

    for (size_t u = 0; u < U; ++u) {
      uint32_t l = label_with_blank[u];
      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha_b[u][t] + log_beta_b[u][t]);
    }
    for (size_t l = 0; l < L; ++l) {
      (*dy_b)[l][t] = y[l][t] - static_cast<TT>(exp(prob_sum[l] - log_pzx));
    }
  }
}

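// Converts each batch label sequence into the blank-interleaved form used by the CTC
// recursions (blank, l1, blank, l2, ..., blank), optionally collapsing repeated labels
// first and validating the label length against the sequence length.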
void CTCLossCPUKernel::GenLabelWithBlank(const uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label,
                                         std::vector<std::vector<uint32_t>> *label_with_blank) const {
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> l;
    const std::vector<uint32_t> &label = batch_label[b];
    bool has_blank = false;
    for (size_t i = 0; i < label.size(); ++i) {
      if (i == 0 || !preprocess_collapse_repeated_ || label[i] != label[i - 1]) {
        if (label[i] >= num_class_ - 1) {
          has_blank = true;
        } else {
          if (has_blank) {
            MS_LOG(EXCEPTION) << "Invalid label (index >= num_class - 1) should not appear between two valid labels";
          }
          l.push_back(label[i]);
        }
      }
    }
    if (!ignore_longer_outputs_than_inputs_ && l.size() > seq_len[b]) {
      MS_LOG(EXCEPTION) << "Input time (sequence length) should be no less than output size (label length), but got "
                        << seq_len[b] << " < " << l.size();
    }

    (*label_with_blank)[b].reserve(2 * l.size() + 1);
    for (auto l_i : l) {
      (*label_with_blank)[b].push_back(blank_index_);
      (*label_with_blank)[b].push_back(l_i);
    }
    (*label_with_blank)[b].push_back(blank_index_);
  }
}

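// Per-batch CTC loss: validates sequence lengths, rebuilds each label sequence from the
// sparse (indices, values) representation, interleaves blanks, runs the softmax and the
// forward/backward recursions, then writes the loss and the gradient of the inputs.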
template <typename T>
void CTCLossCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                    const std::vector<AddressPtr> &outputs) const {
  const auto *inputs_addr = reinterpret_cast<T *>(inputs[0]->addr);
  const auto *labels_indices_addr = reinterpret_cast<uint64_t *>(inputs[1]->addr);
  const auto *labels_values_addr = reinterpret_cast<uint32_t *>(inputs[2]->addr);
  const auto *sequence_length_addr = reinterpret_cast<uint32_t *>(inputs[3]->addr);
  auto *loss_addr = reinterpret_cast<T *>(outputs[0]->addr);
  auto *gradient_addr = reinterpret_cast<T *>(outputs[1]->addr);

  std::vector<std::vector<uint32_t>> label_batch;
  std::vector<std::vector<uint32_t>> labels_with_blank;
  std::vector<uint64_t> each_label_length;

  label_batch.resize(batch_size_);
  labels_with_blank.resize(batch_size_);
  each_label_length.resize(batch_size_, 0);

  T kLogZero_ = -std::numeric_limits<T>::infinity();
  // check validity of the sequence lengths
  for (size_t b = 0; b < batch_size_; ++b) {
    if (sequence_length_addr[b] == static_cast<uint32_t>(0)) {
      MS_LOG(EXCEPTION) << "Sequence length should be greater than 0, but got " << sequence_length_addr[b];
    }
    if (sequence_length_addr[b] > max_time_) {
      MS_LOG(EXCEPTION) << "Max time should be no less than sequence length, but got " << max_time_ << " < "
                        << sequence_length_addr[b];
    }
  }
  // count how many label values belong to each batch element
  for (size_t i = 0; i < indices_dims_[0]; ++i) {
    const size_t factor = 2;
    auto index = labels_indices_addr[i * factor];
    if (index >= SizeToUlong(each_label_length.size())) {
      MS_LOG(EXCEPTION) << "Index: " << index << " is out of bounds of the label length vector.";
    }
    each_label_length[index]++;
  }

  // convert the sparse (label_indices, label_values) representation to batch_label
  uint64_t cum_sum = 0;
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> *b_value = &label_batch[b];
    for (size_t l = 0; l < each_label_length[b]; ++l) {
      b_value->push_back(labels_values_addr[cum_sum + l]);
    }
    cum_sum += each_label_length[b];
  }

  // convert labels to the blank-interleaved form
  GenLabelWithBlank(sequence_length_addr, label_batch, &labels_with_blank);

  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> label_with_blank = labels_with_blank[b];
    // y_b [num_class, sequence_length]
    std::vector<std::vector<T>> y_b;
    std::vector<std::vector<T>> dy;
    std::vector<std::vector<T>> log_alpha_b;
    std::vector<std::vector<T>> log_beta_b;
    MatrixFromVector(num_class_, sequence_length_addr[b], &y_b, kLogZero_);
    MatrixFromVector(y_b.size(), y_b[0].size(), &dy, T(0));
    MatrixFromVector(label_with_blank.size(), sequence_length_addr[b], &log_alpha_b, kLogZero_);
    MatrixFromVector(label_with_blank.size(), sequence_length_addr[b], &log_beta_b, kLogZero_);
    InnerSoftMax(inputs_addr, &y_b, sequence_length_addr[b], num_class_, batch_size_, b);
    CalculateFwdVar(label_with_blank, y_b, &log_alpha_b);
    CalculateBwdVar(label_with_blank, y_b, &log_beta_b);

    // total log-probability of the label sequence: logsumexp over alpha(u, 0) + beta(u, 0)
    T log_pzx = kLogZero_;
    for (size_t u = 0; u < label_with_blank.size(); ++u) {
      log_pzx = LogSumExp(log_pzx, log_alpha_b[u][0] + log_beta_b[u][0]);
    }
    loss_addr[b] = -log_pzx;
    CalculateGrad(label_with_blank, y_b, log_alpha_b, log_beta_b, log_pzx, &dy);

    for (size_t t = 0; t < sequence_length_addr[b]; ++t) {
      for (size_t c = 0; c < num_class_; ++c) {
        gradient_addr[t * batch_size_ * num_class_ + b * num_class_ + c] = dy[c][t];
      }
    }
  }
}
}  // namespace kernel
}  // namespace mindspore