• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_
18 
19 #include <memory>
20 #include <string>
21 #include <set>
22 #include <vector>
23 #include "runtime/device/ascend/kernel_select_ascend.h"
24 #include "backend/kernel_compiler/kernel_query.h"
25 #include "backend/kernel_compiler/oplib/oplib.h"
26 #include "backend/session/anf_runtime_algorithm.h"
27 #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
28 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
29 
30 namespace mindspore {
31 namespace opt {
32 class KernelSelect {
33  public:
34   KernelSelect() = default;
35   virtual ~KernelSelect() = default;
SelectKernel(const CNodePtr & cnode)36   virtual void SelectKernel(const CNodePtr &cnode) { device::ascend::SelectKernelInfo(cnode); }
37 };
38 using KernelSelectPtr = std::shared_ptr<KernelSelect>;
39 
40 class SupportedChecker {
41  public:
42   SupportedChecker() = default;
43   virtual ~SupportedChecker() = default;
CheckAICoreSupported(const AnfNodePtr & anf_node,const kernel::KernelBuildInfoPtr & select_kernel_build_info)44   virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node,
45                                     const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
46     return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info);
47   }
CheckAICPUSupported(const AnfNodePtr & anf_node,const kernel::KernelBuildInfoPtr & select_kernel_build_info)48   virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node,
49                                    const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
50     return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info);
51   }
52 };
53 using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
54 
55 class KernelQuery {
56  public:
57   KernelQuery() = default;
58   virtual ~KernelQuery() = default;
Query(const CNodePtr & kernel_node,std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list)59   virtual void Query(const CNodePtr &kernel_node,
60                      std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
61     kernel::KernelQuery(kernel_node, kernel_info_list);
62   }
IsTbeRef(const AnfNodePtr & node)63   virtual bool IsTbeRef(const AnfNodePtr &node) {
64     MS_EXCEPTION_IF_NULL(node);
65     if (!node->isa<CNode>()) {
66       return false;
67     }
68     auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(node), node);
69     if (op_info != nullptr) {
70       return op_info->is_ref();
71     }
72     return false;
73   }
74 };
75 using KernelQueryPtr = std::shared_ptr<KernelQuery>;
76 
77 class TbeKernelQuery {
78  public:
79   TbeKernelQuery() = default;
80   virtual ~TbeKernelQuery() = default;
GetTbeKernelMetaInfo(const CNodePtr & kernel_node,std::vector<std::shared_ptr<kernel::KernelBuildInfo>> * kernel_info_list)81   virtual void GetTbeKernelMetaInfo(const CNodePtr &kernel_node,
82                                     std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
83     kernel::TbeMetadataInfo(kernel_node, kernel_info_list);
84   }
85 };
86 using TbeKernelQueryPtr = std::shared_ptr<TbeKernelQuery>;
87 
88 class OpFinder {
89  public:
90   OpFinder() = default;
91   virtual ~OpFinder() = default;
GetOpRegisteredOutputNum(const std::string & op_name,const CNodePtr & cnode)92   virtual int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) {
93     auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
94     if (op_info == nullptr) {
95       return -1;
96     }
97     return op_info->outputs_ptr().size();
98   }
99 };
100 using OpFinderPtr = std::shared_ptr<OpFinder>;
101 
102 void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
103                             const AnfNodePtr &trans_data, const std::string &reshape_type = {""},
104                             const TypeId &type_id = kTypeUnknown);
105 
106 CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
107                         const bool need_padding, const std::string &op_name,
108                         const std::vector<int64_t> &perm = std::vector<int64_t>{});
109 
110 CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
111                               const TypeId &input_type, const TypeId &output_type,
112                               const abstract::BaseShapePtr &origin_shape, const TypeId &origin_type,
113                               const std::string &reshape_type = std::string{});
114 
115 AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
116                                  const KernelSelectPtr &kernel_select);
117 
118 AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &orig_node, const AnfNodePtr &node,
119                                   const KernelSelectPtr &kernel_select);
120 
121 CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
122 
123 AnfNodePtr CreateTensorMoveOp(const FuncGraphPtr &graph, const AnfNodePtr &node);
124 
125 AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
126                                  const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input);
127 
128 const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
129 }  // namespace opt
130 }  // namespace mindspore
131 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_HELPER_H_
132