1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ 19 20 #include <string> 21 #include <tuple> 22 #include <vector> 23 #include <memory> 24 #include "ir/anf.h" 25 #include "ir/func_graph.h" 26 #include "utils/hash_map.h" 27 #include "backend/common/graph_kernel/model/lite_graph.h" 28 #include "backend/common/graph_kernel/core/graph_kernel_callback.h" 29 30 namespace mindspore::graphkernel { 31 constexpr auto kGraphKernelDumpPath = "graph_kernel_dump"; 32 constexpr auto kAllTarget = "ALL"; 33 constexpr auto kOutputsFormat = "outputs_format"; 34 constexpr auto kAttrToPrim = "to_prim"; 35 36 using OpWithLevel = std::tuple<std::string, unsigned int, PrimitivePtr>; 37 38 class GkUtils { 39 public: 40 /** 41 * @brief Extract kernel name from nodes, only the real kernel CNode is processed. 42 * @param[in] nodes The node list 43 * @param[in] prefix The prefix of result name 44 * @param[in] postfix The postfix of result name 45 * @return The string concatenated by the names of all cnodes 46 */ 47 static std::string ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix = "", 48 const std::string &postfix = ""); 49 50 /** 51 * @brief Spread the MakeTuple in node list 52 * @param[in] nodes 53 * @param[in] begin_index 54 * @example 55 * input 56 * nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ] 57 * begin_index: 1 58 * output 59 * [b, i, j, c, d, x, y, z] 60 * @return std::vector<AnfNodePtr> 61 */ 62 static AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0); 63 64 /** 65 * @brief Filter operators by target, op level, and enable/disable flags. 66 * @param[in] ops_with_level the default operator list 67 * @param[in] level enabled op level 68 * @param[in] enable_ops_only the "enable_xxx_ops_only" flag 69 * @param[in] enable_ops the "enable_xxx_ops" flag 70 * @param[in] disable_ops the "disable_xxx_ops" flag 71 * @return Available primitive list 72 */ 73 static std::vector<PrimitivePtr> GetValidOps(const std::vector<OpWithLevel> &ops_with_level, unsigned int level, 74 const std::vector<std::string> &enable_ops_only, 75 const std::vector<std::string> &enable_ops, 76 const std::vector<std::string> &disable_ops); 77 78 /** 79 * @brief Filter excluded operators which can not be enabled from operator list. 80 * @param[in] ops the default operator list 81 * @return The operator list after removing the excluded operators 82 */ 83 static std::vector<PrimitivePtr> FilterExcludedOps(const std::vector<PrimitivePtr> &ops); 84 85 /** 86 * @brief Check whether graphkernel supports the node 87 */ 88 static bool IsKeepBasicNode(const AnfNodePtr &node); 89 90 /** 91 * @brief Create CNode. 92 */ 93 static CNodePtr NewRealCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, 94 const std::vector<inner::NodeBase> &out_info_list, const CallbackPtr &cb); 95 96 /** 97 * @brief Transform LiteGraph to FuncGraph. 98 */ 99 static FuncGraphPtr LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph, const CallbackPtr &cb); 100 101 /** 102 * @brief Transform FuncGraph to LiteGraph 103 * @param[in] func_graph the anf FuncGraph 104 * @param[out] op_node_map store the relationship of inner node and anf node. 105 */ 106 static inner::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph, 107 HashMap<inner::NodePtr, AnfNodePtr> *op_node_map = nullptr); 108 109 /** 110 * @brief Get manager of func graph. If there is no manager, a new one will be created. 111 */ 112 static FuncGraphManagerPtr GetFuncGraphManager(const FuncGraphPtr &func_graph); 113 114 /** 115 * @brief Update func graph manager. Do nothing if the manager is nullptr. 116 */ 117 static void UpdateFuncGraphManager(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph); 118 119 /** 120 * @brief Get op's prim by op name. 121 */ 122 static PrimitivePtr GetOpsPrim(const std::string &name); 123 124 /** 125 * @brief Get the real kernel nodes from a FuncGraph, the parameters and output nodes are extracted if the input_list 126 * and output_list is set. 127 */ 128 static void GetValidKernelNodes(const FuncGraphPtr &func_graph, AnfNodePtrList *node_list, 129 AnfNodePtrList *input_list = nullptr, AnfNodePtrList *output_list = nullptr); 130 131 /** 132 * @brief Get n in NCHWnc. NCHWnc format is used in conv2dfusion expander in lite. n is the value of tiling channel. 133 * @example 134 * input 135 * NCHW16c 136 * output 137 * 16 138 */ 139 static int64_t GetChannelInConvFormat(const std::string &format_string); 140 141 /** 142 * @brief Get all graphkernel's nodes in topo order. 143 */ 144 static AnfNodePtrList GetGraphKernelNodes(const FuncGraphPtr &func_graph); 145 146 static bool UseAkgCceLib(const AnfNodePtr &node); 147 }; 148 } // namespace mindspore::graphkernel 149 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_ 150