1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_ 19 20 #include <functional> 21 #include <map> 22 #include <memory> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 27 #include "backend/kernel_compiler/cpu/cpu_kernel.h" 28 #include "backend/kernel_compiler/oplib/oplib.h" 29 #include "runtime/device/cpu/kernel_select_cpu.h" 30 #include "utils/ms_utils.h" 31 32 namespace mindspore { 33 namespace kernel { 34 using mindspore::device::cpu::KernelAttr; 35 using CPUKernelCreator = std::function<std::shared_ptr<CPUKernel>()>; 36 37 class CPUKernelFactory { 38 public: 39 static CPUKernelFactory &GetInstance(); 40 void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator); 41 std::shared_ptr<CPUKernel> Create(const std::string &kernel_name, const CNodePtr &apply_kernel); 42 void SetKernelAttrs(const std::shared_ptr<kernel::OpInfo> op_info, std::vector<KernelAttr> *kernel_attrs); 43 void UpdateKernelAttrs(const std::string &kernel_name, const std::vector<KernelAttr> &kernel_attrs); 44 std::vector<KernelAttr> GetSupportedKernelAttrList(const std::string &kernel_name); 45 46 private: 47 CPUKernelFactory() = default; 48 ~CPUKernelFactory() = default; 49 DISABLE_COPY_AND_ASSIGN(CPUKernelFactory) 50 std::pair<bool, size_t> CPUKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo &kernel_info); 51 bool CPUKernelSingleAttrCheck(const KernelAttr &kernel_attr, const KernelBuildInfo &kernel_info) const; 52 std::map<std::string, std::vector<std::pair<KernelAttr, CPUKernelCreator>>> name_to_attr_creator_; 53 }; 54 55 class CPUKernelRegistrar { 56 public: CPUKernelRegistrar(const std::string & kernel_name,const KernelAttr & kernel_attr,CPUKernelCreator && kernel_creator)57 CPUKernelRegistrar(const std::string &kernel_name, const KernelAttr &kernel_attr, CPUKernelCreator &&kernel_creator) { 58 CPUKernelFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(kernel_creator)); 59 } 60 ~CPUKernelRegistrar() = default; 61 }; 62 63 #define MS_REG_CPU_KERNEL(OPNAME, ATTR, OPCLASS) MS_REG_CPU_KERNEL_(__COUNTER__, OPNAME, ATTR, OPCLASS) 64 #define MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) 65 #define _MS_REG_CPU_KERNEL_(COUNT, OPNAME, ATTR, OPCLASS) \ 66 static_assert(std::is_base_of<CPUKernel, OPCLASS>::value, " must be base of CPUKernel"); \ 67 static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_reg(#OPNAME, ATTR, \ 68 []() { return std::make_shared<OPCLASS>(); }); 69 70 #define MS_REG_CPU_KERNEL_T(OPNAME, ATTR, OPCLASS, T) MS_REG_CPU_KERNEL_T_(__COUNTER__, OPNAME, ATTR, OPCLASS, T) 71 #define MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) 72 #define _MS_REG_CPU_KERNEL_T_(COUNT, OPNAME, ATTR, OPCLASS, T) \ 73 static_assert(std::is_base_of<CPUKernel, OPCLASS<T>>::value, " must be base of CPUKernel"); \ 74 static const CPUKernelRegistrar g_cpu_kernel_##COUNT##_##OPNAME##_##T##_reg( \ 75 #OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T>>(); }); 76 77 #define MS_REG_CPU_KERNEL_T_S(OPNAME, ATTR, OPCLASS, T, S) \ 78 static_assert(std::is_base_of<CPUKernel, OPCLASS<T, S>>::value, " must be base of CPUKernel"); \ 79 static const CPUKernelRegistrar g_cpu_kernel_##OPNAME##_##T##_##S##_reg( \ 80 #OPNAME, ATTR, []() { return std::make_shared<OPCLASS<T, S>>(); }); 81 } // namespace kernel 82 } // namespace mindspore 83 84 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_FACTORY_H_ 85