1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_ 18 #define MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_ 19 20 #include <vector> 21 #include <string> 22 #include <utility> 23 #include <memory> 24 #include "mindapi/base/base.h" 25 #include "mindapi/ir/common.h" 26 #include "mindapi/ir/anf.h" 27 #include "mindapi/ir/primitive.h" 28 #include "mindapi/ir/value.h" 29 #include "mindapi/ir/utils.h" 30 31 namespace mindspore { 32 class FuncGraphManager; 33 } 34 35 namespace mindspore::api { 36 /// \brief FuncGraph defines interface for a function graph. 37 class MIND_API FuncGraph : public Value { 38 public: 39 MIND_API_BASE_MEMBER(FuncGraph); 40 41 /// \brief Get the input parameters. 42 /// 43 /// \return Input parameters of this graph. 44 std::vector<AnfNodePtr> get_inputs() const; 45 46 /// \brief Get all parameters. 47 /// 48 /// \return All parameters of this graph. 49 std::vector<AnfNodePtr> parameters() const; 50 51 /// \brief Adds a parameter to this graph. 52 /// 53 /// \param[in] p The parameter to be added. 54 void add_parameter(const ParameterPtr &p); 55 56 /// \brief Adds a new parameter to this graph. 57 /// 58 /// \return The new added parameter. 59 ParameterPtr add_parameter(); 60 61 /// \brief Get the output node. 62 /// 63 /// \return The output node, nullptr if output not set. 64 AnfNodePtr output() const; 65 66 /// \brief Get the return CNode. 67 /// 68 /// \return The return CNode, nullptr if no return node. 69 CNodePtr get_return() const; 70 71 /// \brief Set the output node. 72 /// 73 /// \param[in] value The output node to be set. 74 /// \param[in] force_new_ret If true, a new return node is always created. 75 void set_output(const AnfNodePtr &value, bool force_new_ret = false); 76 77 /// \brief Set the return node. 78 /// 79 /// \param[in] cnode The return CNode to be set. 80 void set_return(const CNodePtr &cnode); 81 82 /// \brief Creates a new CNode in this graph. 83 /// 84 /// \param[in] inputs The input nodes of the new CNode. 85 /// 86 /// \return The created CNode. 87 CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()); 88 89 /// \brief Creates a new primitive CNode in this graph. 90 /// 91 /// \param[in] primitive The primitive of the new CNode. 92 /// \param[in] prim_inputs The argument inputs of the primitive CNode. 93 /// 94 /// \return The created primitive CNode. 95 CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs); 96 97 /// \brief Get all nodes in this graph. 98 /// 99 /// \return All nodes in this graph. 100 std::vector<AnfNodePtr> nodes() const; 101 102 /// \brief Check whether an attribute is set for this graph. 103 /// 104 /// \param[in] key The attribute key (name). 105 /// 106 /// \return True if the attribute with the given key is set, false otherwise. 107 bool has_attr(const std::string &key) const; 108 109 /// \brief Get an attribute value by its key. 110 /// 111 /// \param[in] key The attribute key (name). 112 /// 113 /// \return The attribute value for the given key, nullptr if attribute not found. 114 ValuePtr get_attr(const std::string &key) const; 115 116 /// \brief Set an attribute value. 117 /// 118 /// \param[in] key The attribute key (name). 119 /// \param[in] value The attribute value. 120 void set_attr(const std::string &key, const ValuePtr &value); 121 122 /// \brief Get the manager for this graph. 123 /// 124 /// \return The manager of this graph, nullptr if not set. 125 FuncGraphManagerPtr manager() const; 126 127 /// \brief Creates an empty function graph. 128 /// 129 /// \return The created function graph. 130 static FuncGraphPtr Create(); 131 132 /// \brief Topological sort a graph from the given end node. 133 /// 134 /// \param[in] node The end node of the graph to be sorted. 135 /// 136 /// \return The sorted nodes. 137 static std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &node); 138 }; 139 140 /// \brief FuncGraphManager defines interface for function graph management. 141 class MIND_API FuncGraphManager { 142 public: 143 /// \brief Create FuncGraphManager with the given implementor object. 144 /// 145 /// \param[in] impl The pointer to the implementor object. 146 explicit FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl); 147 148 /// \brief Get the shared_ptr to the underly implementation object. 149 /// 150 /// \return The shared_ptr to the underly implementation object. impl()151 const std::shared_ptr<mindspore::FuncGraphManager> &impl() const { return impl_; } 152 153 /// \brief Replace an old node with a new node, related edges are all updated. 154 /// 155 /// \param[in] old_node The old node to be replaced. 156 /// \param[in] new_node The new node that replace the old one. 157 /// 158 /// \return True if the node is successfully replaced, false otherwise. 159 bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); 160 161 /// \brief Change an existed edge by replace its input node. 162 /// 163 /// \param[in] node The output node of the edge. 164 /// \param[in] index The input index in output node. 165 /// \param[in] value The new input node of the edge. 166 void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); 167 168 /// \brief Adds a new edge between the given two nodes. 169 /// 170 /// \param[in] node The output node of the edge. 171 /// \param[in] value The input node of the edge. 172 void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value); 173 174 /// \brief Find users of the given node. 175 /// 176 /// \param[in] node The node. 177 /// 178 /// \return Users of the given node, empty if user not found. 179 std::vector<std::pair<AnfNodePtr, int>> GetUsers(const AnfNodePtr &node) const; 180 181 /// \brief Manage the give function graph. 182 /// 183 /// \param[in] func_graph The function graph to be managed. 184 /// \param[in] manage If true, the created manager will be set in the graph. 185 /// 186 /// \return The manager that manages the given function graph. 187 static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true); 188 189 private: 190 const std::shared_ptr<mindspore::FuncGraphManager> impl_; 191 }; 192 } // namespace mindspore::api 193 #endif // MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_ 194