• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_H
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_H
19 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/ascend_kernel_mod.h"
#include "backend/kernel_compiler/task_stream.h"
27 
28 namespace mindspore {
29 namespace kernel {
30 class RtKernel : public AscendKernelMod {
31  public:
32   RtKernel();
33   ~RtKernel() override;
34   virtual bool Init(const AnfNodePtr &anf_node);
35   const std::vector<size_t> &GetInputSizeList() const override;
36   const std::vector<size_t> &GetOutputSizeList() const override;
37   const std::vector<size_t> &GetWorkspaceSizeList() const override;
38 
39  protected:
40   mutable std::vector<size_t> input_size_list_;
41   mutable std::vector<size_t> output_size_list_;
42   mutable std::vector<size_t> workspace_size_list_;
43 };
44 
45 using RTKernelPtr = std::shared_ptr<RtKernel>;
46 
47 using RtKernelCreater = std::function<std::shared_ptr<RtKernel>()>;
48 class RtKernelFactory {
49   RtKernelFactory() = default;
50   ~RtKernelFactory() = default;
51 
52  public:
53   static RtKernelFactory &Get();
54   void Register(const std::string &name, RtKernelCreater &&fun);
55   static std::shared_ptr<RtKernel> Create(const std::string &name);
56 
57  private:
58   std::map<string, RtKernelCreater> fmap_;
59 };
60 
61 class _RtKernelRegister {
62  public:
_RtKernelRegister(const std::string & name,RtKernelCreater && fun)63   _RtKernelRegister(const std::string &name, RtKernelCreater &&fun) {
64     RtKernelFactory::Get().Register(name, std::move(fun));
65   }
66   ~_RtKernelRegister() = default;
67 };
68 
// Implementation macro behind MS_REG_RTKERNEL.
//   KNAME - bare identifier; stringified to form the registration name and
//           pasted into the name of the static registrar object.
//   clazz - kernel type; must derive from RtKernel (checked at compile time).
// Expands to a static_assert plus a static const _RtKernelRegister whose
// creator returns std::make_shared<clazz>(). The expansion already ends in a
// semicolon. The registrar is named g_<KNAME>_RtKernel_reg, which avoids a
// double underscore because such identifiers are reserved.
#define _MS_REG_RTKERNEL_REG(KNAME, clazz)                                                     \
  static_assert(std::is_base_of<RtKernel, clazz>::value, #clazz " must derive from RtKernel"); \
  static const _RtKernelRegister g_##KNAME##_RtKernel_reg(#KNAME, []() { return std::make_shared<clazz>(); });

// Public entry point. The extra level of indirection lets macro arguments be
// expanded before they are stringified and pasted by the macro above.
#define MS_REG_RTKERNEL(KNAME, clazz) _MS_REG_RTKERNEL_REG(KNAME, clazz)
74 }  // namespace kernel
75 }  // namespace mindspore
76 
77 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_RTS_RT_KERNEL_H
78