• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNIQUE_WITH_PAD_CPU_KERNEL_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNIQUE_WITH_PAD_CPU_KERNEL_H_
19 
20 #include <memory>
21 #include <unordered_map>
22 #include <vector>
23 #include <map>
24 #include <functional>
25 #include "plugin/device/cpu/kernel/cpu_kernel.h"
26 #include "plugin/factory/ms_factory.h"
27 #include "plugin/device/cpu/kernel/unique_cpu_kernel.h"
28 #include "ops/op_utils.h"
29 
30 namespace mindspore {
31 namespace kernel {
// Positions of the kernel's inputs. Judging by the names, slot 0 holds the data tensor and slot 1 the pad value;
// NOTE(review): confirm against the Launch implementation in the .cc file.
inline static constexpr size_t kInputIndex = 0;
inline static constexpr size_t kPadNumIndex = 1;
// Number of input and output tensors associated with UniqueWithPad.
inline static constexpr size_t kUniqueWithPadInputsNum = 2;
inline static constexpr size_t kUniqueWithPadOutputsNum = 2;
36 class UniqueWithPadCpuKernelMod : public UniqueCpuKernelMod {
37  public:
38   UniqueWithPadCpuKernelMod() = default;
39   ~UniqueWithPadCpuKernelMod() override = default;
40 
Init(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)41   bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
42     dtype_ = inputs[0]->dtype_id();
43     auto batch_rank = ops::get_batch_rank(primitive_);
44     if (batch_rank < 0) {
45       return false;
46     }
47     batch_rank_ = static_cast<size_t>(batch_rank);
48     return true;
49   }
50 
51   int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
52 
53   bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
54               const std::vector<KernelTensor *> &outputs) override;
55 
GetOpSupport()56   std::vector<KernelAttr> GetOpSupport() override {
57     static std::vector<KernelAttr> support_list = {KernelAttr()
58                                                      .AddInputAttr(kNumberTypeInt32)
59                                                      .AddInputAttr(kNumberTypeInt32)
60                                                      .AddOutputAttr(kNumberTypeInt32)
61                                                      .AddOutputAttr(kNumberTypeInt32),
62                                                    KernelAttr()
63                                                      .AddInputAttr(kNumberTypeInt64)
64                                                      .AddInputAttr(kNumberTypeInt64)
65                                                      .AddOutputAttr(kNumberTypeInt64)
66                                                      .AddOutputAttr(kNumberTypeInt64),
67                                                    KernelAttr()
68                                                      .AddInputAttr(kNumberTypeFloat32)
69                                                      .AddInputAttr(kNumberTypeFloat32)
70                                                      .AddOutputAttr(kNumberTypeFloat32)
71                                                      .AddOutputAttr(kNumberTypeInt32)};
72     return support_list;
73   }
74 
75  private:
76   template <typename T>
77   void PadOutput(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs,
78                  const std::vector<size_t> &start);
79   // Disable update output shape because parent class 'UniqueCpuKernelMod'(for Unique op) need update output shape, but
80   // UniqueWithPad doesn't need.
IsNeedUpdateOutputShapeAndSize()81   bool IsNeedUpdateOutputShapeAndSize() override { return false; }
82 };
83 }  // namespace kernel
84 }  // namespace mindspore
85 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_UNIQUE_WITH_PAD_CPU_KERNEL_H_
86