1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_ 18 19 #include <memory> 20 #include <string> 21 #include <set> 22 #include <vector> 23 #include "runtime/device/ascend/kernel_select_ascend.h" 24 #include "backend/kernel_compiler/kernel_query.h" 25 #include "backend/kernel_compiler/oplib/oplib.h" 26 #include "backend/session/anf_runtime_algorithm.h" 27 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" 28 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" 29 30 namespace mindspore { 31 namespace opt { 32 class KernelSelect { 33 public: 34 KernelSelect() = default; 35 virtual ~KernelSelect() = default; SelectKernel(const CNodePtr & cnode)36 virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); } 37 }; 38 using KernelSelectPtr = std::shared_ptr<KernelSelect>; 39 40 class SupportedChecker { 41 public: 42 SupportedChecker() = default; 43 virtual ~SupportedChecker() = default; CheckAICoreSupported(const AnfNodePtr & anf_node,const kernel::KernelBuildInfoPtr & select_kernel_build_info)44 virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, 45 const kernel::KernelBuildInfoPtr &select_kernel_build_info) { 46 return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); 47 } CheckAICPUSupported(const AnfNodePtr & anf_node,const kernel::KernelBuildInfoPtr & select_kernel_build_info)48 virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, 49 const kernel::KernelBuildInfoPtr &select_kernel_build_info) { 50 return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); 51 } 52 }; 53 using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>; 54 55 class KernelQuery { 56 public: 57 KernelQuery() = default; 58 virtual ~KernelQuery() = default; Query(const CNodePtr & kernel_node,std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list)59 virtual void Query(const CNodePtr &kernel_node, 60 std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { 61 kernel::KernelQuery(kernel_node, kernel_info_list); 62 } IsTbeRef(const AnfNodePtr & node)63 virtual bool IsTbeRef(const AnfNodePtr &node) { 64 MS_EXCEPTION_IF_NULL(node); 65 if (!node->isa<CNode>()) { 66 return false; 67 } 68 auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(node), node); 69 if (op_info != nullptr) { 70 return op_info->is_ref(); 71 } 72 return false; 73 } 74 }; 75 using KernelQueryPtr = std::shared_ptr<KernelQuery>; 76 77 class TbeKernelQuery { 78 public: 79 TbeKernelQuery() = default; 80 virtual ~TbeKernelQuery() = default; GetTbeKernelMetaInfo(const CNodePtr & kernel_node,std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list)81 virtual void GetTbeKernelMetaInfo(const CNodePtr &kernel_node, 82 std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { 83 kernel::TbeMetadataInfo(kernel_node, kernel_info_list); 84 } 85 }; 86 using TbeKernelQueryPtr = std::shared_ptr<TbeKernelQuery>; 87 88 class OpFinder { 89 public: 90 OpFinder() = default; 91 virtual ~OpFinder() = default; GetOpRegisteredOutputNum(const std::string & op_name,const CNodePtr & cnode)92 virtual int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) { 93 auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode); 94 if (op_info == nullptr) { 95 return -1; 96 } 97 return op_info->outputs_ptr().size(); 98 } 99 }; 100 using OpFinderPtr = std::shared_ptr<OpFinder>; 101 102 void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, 103 const AnfNodePtr &trans_data, const std::string &reshape_type = {""}, 104 const TypeId &type_id = kTypeUnknown); 105 106 CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, 107 const bool need_padding, const std::string &op_name, 108 const std::vector<int64_t> &perm = std::vector<int64_t>{}); 109 110 CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, 111 const TypeId &input_type, const TypeId &output_type, 112 const abstract::BaseShapePtr &origin_shape, const TypeId &origin_type, 113 const std::string &reshape_type = std::string{}); 114 115 AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, 116 const KernelSelectPtr &kernel_select); 117 118 AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node, 119 const KernelSelectPtr &kernel_select); 120 121 CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); 122 123 AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node); 124 125 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, 126 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input); 127 128 const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW}; 129 } // namespace opt 130 } // namespace mindspore 131 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_ 132