• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_
19 
20 #include <memory>
21 #include <unordered_map>
22 #include <vector>
23 #include <algorithm>
24 #include <limits>
25 #include <map>
26 #include "plugin/device/cpu/kernel/cpu_kernel.h"
27 #include "plugin/factory/ms_factory.h"
28 
29 namespace mindspore {
30 namespace kernel {
31 class CTCLossCpuKernelMod : public NativeCpuKernelMod {
32  public:
33   CTCLossCpuKernelMod() = default;
34   ~CTCLossCpuKernelMod() override = default;
35 
36   bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
37 
38   int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
39 
40   bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
41               const std::vector<KernelTensor *> &outputs) override;
42 
43   std::vector<KernelAttr> GetOpSupport() override;
44 
45  private:
46   void GenLabelWithBlank(const uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label,
47                          std::vector<std::vector<uint32_t>> *label_with_blank) const;
48 
49   template <typename T>
50   void CalculateFwdVar(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y,
51                        std::vector<std::vector<T>> *log_alpha_b) const;
52   template <typename T>
53   void CalculateBwdVar(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y,
54                        std::vector<std::vector<T>> *log_beta_b) const;
55   template <typename T>
56   void CalculateGrad(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y,
57                      const std::vector<std::vector<T>> &log_alpha_b, const std::vector<std::vector<T>> &log_beta_b,
58                      const T log_pzx, std::vector<std::vector<T>> *dy) const;
59 
60   template <typename T>
61   void LaunchKernel(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) const;
62 
63   ShapeVector probs_shape_;
64   ShapeVector indices_dims_;
65   ShapeVector labels_dims_;
66   size_t num_class_{0};
67   size_t max_time_{0};
68   size_t batch_size_{0};
69   uint32_t blank_index_{0};
70   TypeId dtype_{kTypeUnknown};
71   bool preprocess_collapse_repeated_{false};
72   bool ctc_merge_repeated_{false};
73   bool ignore_longer_outputs_than_inputs_{false};
74 };
75 }  // namespace kernel
76 }  // namespace mindspore
77 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_
78