• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_INFO_H
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_INFO_H
19 
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "ir/dtype.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/kernel.h"
#include "utils/utils.h"
32 
33 namespace mindspore {
34 namespace kernel {
35 class RtKerDesc {
36  public:
~RtKerDesc()37   virtual ~RtKerDesc() {}
GetKernelInfo()38   virtual std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetKernelInfo() {
39     return std::vector<std::shared_ptr<kernel::KernelBuildInfo>>{};
40   }
41 };
42 
43 using RtKerDescCreater = std::function<std::shared_ptr<RtKerDesc>()>;
44 class RtKerDescFactory {
45   RtKerDescFactory() = default;
46   ~RtKerDescFactory() = default;
47 
48  public:
49   static RtKerDescFactory &Get();
50   void Register(const std::string &name, RtKerDescCreater &&fun);
51   static std::shared_ptr<RtKerDesc> Create(const std::string &name);
52 
53  private:
54   std::map<std::string, RtKerDescCreater> fmap_;
55 };
56 
57 class _RtKerDescRegister {
58  public:
_RtKerDescRegister(const std::string & name,RtKerDescCreater && fun)59   _RtKerDescRegister(const std::string &name, RtKerDescCreater &&fun) {
60     RtKerDescFactory::Get().Register(name, std::move(fun));
61   }
62   ~_RtKerDescRegister() = default;
63 };
64 
// Expands to a compile-time check that `clazz` derives from RtKerDesc plus a
// `static const` _RtKerDescRegister whose constructor registers a creator for
// `clazz` under the stringized KNAME.
// The generated object is named g_<KNAME>_rtkernel_desc_reg. The previous
// paste produced g_<KNAME>__rtkernel_desc_reg, and identifiers containing a
// double underscore are reserved in C++.
#define _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz)                                                       \
  static_assert(std::is_base_of<RtKerDesc, clazz>::value, #clazz " must be derived from RtKerDesc"); \
  static const _RtKerDescRegister g_##KNAME##_rtkernel_desc_reg(#KNAME, []() { return std::make_shared<clazz>(); });

// Public entry point, e.g. MS_REG_RTKERNEL_DESC(MyOp, MyOpDesc);
// The extra expansion level macro-expands KNAME/clazz before they reach the
// # and ## operators in the helper above.
#define MS_REG_RTKERNEL_DESC(KNAME, clazz) _MS_REG_RTKERNEL_DESC_REG(KNAME, clazz)
70 
71 void GetRtKelInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
72 }  // namespace kernel
73 }  // namespace mindspore
74 
75 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_INFO_H
76