• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
19 
20 #include <string>
21 #include <tuple>
22 #include <vector>
23 #include <memory>
24 #include "ir/anf.h"
25 #include "ir/func_graph.h"
26 #include "utils/hash_map.h"
27 #include "backend/common/graph_kernel/model/lite_graph.h"
28 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
29 
30 namespace mindspore::graphkernel {
31 constexpr auto kGraphKernelDumpPath = "graph_kernel_dump";
32 constexpr auto kAllTarget = "ALL";
33 constexpr auto kOutputsFormat = "outputs_format";
34 constexpr auto kAttrToPrim = "to_prim";
35 
36 using OpWithLevel = std::tuple<std::string, unsigned int, PrimitivePtr>;
37 
38 class GkUtils {
39  public:
40   /**
41    * @brief Extract kernel name from nodes, only the real kernel CNode is processed.
42    * @param[in] nodes The node list
43    * @param[in] prefix The prefix of result name
44    * @param[in] postfix The postfix of result name
45    * @return The string concatenated by the names of all cnodes
46    */
47   static std::string ExtractGraphKernelName(const AnfNodePtrList &nodes, const std::string &prefix = "",
48                                             const std::string &postfix = "");
49 
50   /**
51    * @brief Spread the MakeTuple in node list
52    * @param[in] nodes
53    * @param[in] begin_index
54    * @example
55    *   input
56    *     nodes: [ a, b, MakeTuple[i, j], c, d, MakeTuple[x, MakeTuple[y, z]] ]
57    *     begin_index: 1
58    *   output
59    *     [b, i, j, c, d, x, y, z]
60    * @return std::vector<AnfNodePtr>
61    */
62   static AnfNodePtrList SpreadTuples(const AnfNodePtrList &nodes, size_t begin_index = 0);
63 
64   /**
65    * @brief Filter operators by target, op level, and enable/disable flags.
66    * @param[in] ops_with_level the default operator list
67    * @param[in] level enabled op level
68    * @param[in] enable_ops_only the "enable_xxx_ops_only" flag
69    * @param[in] enable_ops the "enable_xxx_ops" flag
70    * @param[in] disable_ops the "disable_xxx_ops" flag
71    * @return Available primitive list
72    */
73   static std::vector<PrimitivePtr> GetValidOps(const std::vector<OpWithLevel> &ops_with_level, unsigned int level,
74                                                const std::vector<std::string> &enable_ops_only,
75                                                const std::vector<std::string> &enable_ops,
76                                                const std::vector<std::string> &disable_ops);
77 
78   /**
79    * @brief Filter excluded operators which can not be enabled from operator list.
80    * @param[in] ops the default operator list
81    * @return The operator list after removing the excluded operators
82    */
83   static std::vector<PrimitivePtr> FilterExcludedOps(const std::vector<PrimitivePtr> &ops);
84 
85   /**
86    * @brief Check whether graphkernel supports the node
87    */
88   static bool IsKeepBasicNode(const AnfNodePtr &node);
89 
90   /**
91    * @brief Create CNode.
92    */
93   static CNodePtr NewRealCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph,
94                                const std::vector<inner::NodeBase> &out_info_list, const CallbackPtr &cb);
95 
96   /**
97    * @brief Transform  LiteGraph to FuncGraph.
98    */
99   static FuncGraphPtr LiteGraph2AnfGraph(const inner::LiteGraphPtr &lite_graph, const CallbackPtr &cb);
100 
101   /**
102    * @brief Transform FuncGraph to LiteGraph
103    * @param[in] func_graph the anf FuncGraph
104    * @param[out] op_node_map store the relationship of inner node and anf node.
105    */
106   static inner::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph,
107                                                 HashMap<inner::NodePtr, AnfNodePtr> *op_node_map = nullptr);
108 
109   /**
110    * @brief Get manager of func graph. If there is no manager, a new one will be created.
111    */
112   static FuncGraphManagerPtr GetFuncGraphManager(const FuncGraphPtr &func_graph);
113 
114   /**
115    * @brief Update func graph manager. Do nothing if the manager is nullptr.
116    */
117   static void UpdateFuncGraphManager(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph);
118 
119   /**
120    * @brief Get op's prim by op name.
121    */
122   static PrimitivePtr GetOpsPrim(const std::string &name);
123 
124   /**
125    * @brief Get the real kernel nodes from a FuncGraph, the parameters and output nodes are extracted if the input_list
126    * and output_list is set.
127    */
128   static void GetValidKernelNodes(const FuncGraphPtr &func_graph, AnfNodePtrList *node_list,
129                                   AnfNodePtrList *input_list = nullptr, AnfNodePtrList *output_list = nullptr);
130 
131   /**
132    * @brief Get n in NCHWnc. NCHWnc format is used in conv2dfusion expander in lite. n is the value of tiling channel.
133    * @example
134    *   input
135    *     NCHW16c
136    *   output
137    *     16
138    */
139   static int64_t GetChannelInConvFormat(const std::string &format_string);
140 
141   /**
142    * @brief Get all graphkernel's nodes in topo order.
143    */
144   static AnfNodePtrList GetGraphKernelNodes(const FuncGraphPtr &func_graph);
145 
146   static bool UseAkgCceLib(const AnfNodePtr &node);
147 };
148 }  // namespace mindspore::graphkernel
149 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_CORE_GRAPH_KERNEL_UTILS_H_
150