1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNELFACTORY_H_ 17 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNELFACTORY_H_ 18 19 #include <functional> 20 #include <map> 21 #include <string> 22 #include <vector> 23 #include <utility> 24 #include <memory> 25 #include "plugin/device/gpu/kernel/gpu_kernel.h" 26 #include "plugin/device/gpu/hal/device/kernel_info_setter.h" 27 #include "kernel/kernel_build_info.h" 28 #include "kernel/common_utils.h" 29 30 namespace mindspore { 31 namespace kernel { 32 using NativeGpuKernelModCreater = std::function<NativeGpuKernelMod *()>; 33 class NativeGpuKernelModFactory { 34 public: 35 ~NativeGpuKernelModFactory() = default; 36 37 static NativeGpuKernelModFactory &GetInstance(); 38 39 void Register(const std::string &kernel_name, const KernelAttr &kernel_attr, NativeGpuKernelModCreater &&creator); 40 41 NativeGpuKernelMod *Create(const std::string &kernel_name, const CNodePtr &apply_kernel); 42 43 bool SearchRegistered(const std::string &kernel_name, const KernelBuildInfoPtr &kernel_info); 44 45 std::string SupportedTypeList(const std::string &kernel_name); 46 47 std::vector<KernelAttr> GetGpuSupportedList(const std::string &kernel_name); 48 49 // Judge whether is registered kernel. 50 bool IsRegistered(const std::string &kernel_name); 51 52 bool ReducePrecision(const std::string &kernel_name, 53 std::shared_ptr<mindspore::kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder); 54 55 std::pair<std::vector<size_t>, TypeId> reduce_flag_{{}, kNumberTypeInt64}; 56 57 private: 58 NativeGpuKernelModFactory() = default; 59 60 NativeGpuKernelModFactory(NativeGpuKernelModFactory const &); 61 62 NativeGpuKernelModFactory &operator=(const NativeGpuKernelModFactory &); 63 64 std::pair<bool, size_t> GpuKernelAttrCheck(const std::string &kernel_name, const KernelBuildInfo *kernel_info); 65 void CheckSM(const KernelBuildInfo *kernel_info, const size_t &input_index); 66 bool CheckIOParam(const std::string &kernel_name, const KernelBuildInfo *kernel_info, 67 std::vector<std::pair<KernelAttr, NativeGpuKernelModCreater>> *iter_second, size_t attr_index); 68 69 // Set output and input ref map to kernel info which will be used by graph compiler. 70 void SetRefMapToKernelInfo(const std::string &kernel_name, size_t index, device::KernelInfo *kernel_info); 71 72 // map to maintain kernel and creator, KernelAttr object and creator must be registered as a pair. 73 std::map<std::string, std::vector<std::pair<KernelAttr, NativeGpuKernelModCreater>>> map_kernel_name_to_creater_; 74 }; 75 76 class GpuKernelRegister { 77 public: GpuKernelRegister(const std::string & kernel_name,const KernelAttr & kernel_attr,NativeGpuKernelModCreater && creator)78 GpuKernelRegister(const std::string &kernel_name, const KernelAttr &kernel_attr, 79 NativeGpuKernelModCreater &&creator) { 80 NativeGpuKernelModFactory::GetInstance().Register(kernel_name, kernel_attr, std::move(creator)); 81 } 82 ~GpuKernelRegister() = default; 83 }; 84 85 #define UNIQUE_KERNEL_NAME(kernel) KERNEL_NAME(g_##kernel##_gpu_kernel_reg, __COUNTER__) 86 #define KERNEL_NAME(kernel, cnt) MERGE(kernel, cnt) 87 #define MERGE(kernel, cnt) kernel##cnt 88 89 #define MS_REG_GPU_KERNEL(OPNAME, OPCLASS) \ 90 static_assert(std::is_base_of<NativeGpuKernelMod, OPCLASS>::value, " must be base of NativeGpuKernelMod"); \ 91 static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, KernelAttr(), []() { return new OPCLASS(); }); 92 93 // regular register of fixed accuracy kernels 94 #define MS_REG_GPU_KERNEL_REGULAR(OPNAME, ATTR, OPCLASS) \ 95 static_assert(std::is_base_of<NativeGpuKernelMod, OPCLASS>::value, " must be base of NativeGpuKernelMod"); \ 96 static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS(); }); 97 98 // register of mixed accuracy kernels which use template and maintain one typename, ignore input num 99 #define MS_REG_GPU_KERNEL_SAME(OPNAME, ATTR, OPCLASS, T) \ 100 static_assert(std::is_base_of<NativeGpuKernelMod, OPCLASS<T>>::value, " must be base of NativeGpuKernelMod"); \ 101 static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T>(); }); 102 103 // register of mixed accuracy kernels which use template and maintain one typename 104 #define MS_REG_GPU_KERNEL_ONE(OPNAME, ATTR, OPCLASS, T) \ 105 static_assert(std::is_base_of<NativeGpuKernelMod, OPCLASS<T>>::value, " must be base of NativeGpuKernelMod"); \ 106 static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T>(); }); 107 108 // register of mixed accuracy kernels which use template and maintain two typename 109 #define MS_REG_GPU_KERNEL_TWO(OPNAME, ATTR, OPCLASS, T, S) \ 110 static_assert(std::is_base_of<NativeGpuKernelMod, OPCLASS<T, S>>::value, " must be base of NativeGpuKernelMod"); \ 111 static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T, S>(); }); 112 113 // register of mixed accuracy kernels which use template and maintain three typename 114 #define MS_REG_GPU_KERNEL_THREE(OPNAME, ATTR, OPCLASS, T, S, G) \ 115 static_assert(std::is_base_of<NativeGpuKernelMod, OPCLASS<T, S, G>>::value, " must be base of NativeGpuKernelMod"); \ 116 static const GpuKernelRegister UNIQUE_KERNEL_NAME(OPNAME)(#OPNAME, ATTR, []() { return new OPCLASS<T, S, G>(); }); 117 } // namespace kernel 118 } // namespace mindspore 119 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_GPUKERNELFACTORY_H_ 120