• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
18 #define MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
19 
20 #include <vector>
21 #include <string>
22 #include <utility>
23 #include <memory>
24 #include "mindapi/base/base.h"
25 #include "mindapi/ir/common.h"
26 #include "mindapi/ir/anf.h"
27 #include "mindapi/ir/primitive.h"
28 #include "mindapi/ir/value.h"
29 #include "mindapi/ir/utils.h"
30 
31 namespace mindspore {
32 class FuncGraphManager;  // Forward declaration of the implementation type wrapped by api::FuncGraphManager.
33 }  // namespace mindspore
34 
35 namespace mindspore::api {
36 /// \brief FuncGraph defines the interface for a function graph.
37 class MIND_API FuncGraph : public Value {
38  public:
39   MIND_API_BASE_MEMBER(FuncGraph);
40 
41   /// \brief Get the input parameters.
42   ///
43   /// \return Input parameters of this graph.
44   std::vector<AnfNodePtr> get_inputs() const;
45 
46   /// \brief Get all parameters.
47   ///
48   /// \return All parameters of this graph.
49   std::vector<AnfNodePtr> parameters() const;
50 
51   /// \brief Add a parameter to this graph.
52   ///
53   /// \param[in] p The parameter to be added.
54   void add_parameter(const ParameterPtr &p);
55 
56   /// \brief Add a new parameter to this graph.
57   ///
58   /// \return The newly added parameter.
59   ParameterPtr add_parameter();
60 
61   /// \brief Get the output node.
62   ///
63   /// \return The output node, nullptr if the output is not set.
64   AnfNodePtr output() const;
65 
66   /// \brief Get the return CNode.
67   ///
68   /// \return The return CNode, nullptr if there is no return node.
69   CNodePtr get_return() const;
70 
71   /// \brief Set the output node.
72   ///
73   /// \param[in] value The output node to be set.
74   /// \param[in] force_new_ret If true, a new return node is always created (default: false).
75   void set_output(const AnfNodePtr &value, bool force_new_ret = false);
76 
77   /// \brief Set the return node.
78   ///
79   /// \param[in] cnode The return CNode to be set.
80   void set_return(const CNodePtr &cnode);
81 
82   /// \brief Create a new CNode in this graph.
83   ///
84   /// \param[in] inputs The input nodes of the new CNode (default: an empty vector).
85   ///
86   /// \return The created CNode.
87   CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
88 
89   /// \brief Create a new primitive CNode in this graph.
90   ///
91   /// \param[in] primitive The primitive of the new CNode.
92   /// \param[in] prim_inputs The argument inputs of the primitive CNode.
93   ///
94   /// \return The created primitive CNode.
95   CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
96 
97   /// \brief Get all nodes in this graph.
98   ///
99   /// \return All nodes in this graph.
100   std::vector<AnfNodePtr> nodes() const;
101 
102   /// \brief Check whether an attribute is set for this graph.
103   ///
104   /// \param[in] key The attribute key (name).
105   ///
106   /// \return True if the attribute with the given key is set, false otherwise.
107   bool has_attr(const std::string &key) const;
108 
109   /// \brief Get an attribute value by its key.
110   ///
111   /// \param[in] key The attribute key (name).
112   ///
113   /// \return The attribute value for the given key, nullptr if the attribute is not found.
114   ValuePtr get_attr(const std::string &key) const;
115 
116   /// \brief Set an attribute value.
117   ///
118   /// \param[in] key The attribute key (name).
119   /// \param[in] value The attribute value.
120   void set_attr(const std::string &key, const ValuePtr &value);
121 
122   /// \brief Get the manager for this graph.
123   ///
124   /// \return The manager of this graph, nullptr if not set.
125   FuncGraphManagerPtr manager() const;
126 
127   /// \brief Create an empty function graph.
128   ///
129   /// \return The created function graph.
130   static FuncGraphPtr Create();
131 
132   /// \brief Topologically sort a graph starting from the given end node.
133   ///
134   /// \param[in] node The end node of the graph to be sorted.
135   ///
136   /// \return The sorted nodes.
137   static std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &node);
138 };
139 
140 /// \brief FuncGraphManager defines the interface for function graph management.
141 class MIND_API FuncGraphManager {
142  public:
143   /// \brief Create a FuncGraphManager with the given implementor object.
144   ///
145   /// \param[in] impl The pointer to the implementor object.
146   explicit FuncGraphManager(const std::shared_ptr<mindspore::FuncGraphManager> &impl);
147 
148   /// \brief Get the shared_ptr to the underlying implementation object.
149   ///
150   /// \return The shared_ptr to the underlying implementation object.
impl()151   const std::shared_ptr<mindspore::FuncGraphManager> &impl() const { return impl_; }
152 
153   /// \brief Replace an old node with a new node; all related edges are updated.
154   ///
155   /// \param[in] old_node The old node to be replaced.
156   /// \param[in] new_node The new node that replaces the old one.
157   ///
158   /// \return True if the node is successfully replaced, false otherwise.
159   bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
160 
161   /// \brief Change an existing edge by replacing its input node.
162   ///
163   /// \param[in] node The output node of the edge.
164   /// \param[in] index The input index in the output node.
165   /// \param[in] value The new input node of the edge.
166   void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
167 
168   /// \brief Add a new edge between the given two nodes.
169   ///
170   /// \param[in] node The output node of the edge.
171   /// \param[in] value The input node of the edge.
172   void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value);
173 
174   /// \brief Find users of the given node.
175   ///
176   /// \param[in] node The node.
177   ///
178   /// \return Users of the given node, empty if no user is found.
179   std::vector<std::pair<AnfNodePtr, int>> GetUsers(const AnfNodePtr &node) const;
180 
181   /// \brief Manage the given function graph.
182   ///
183   /// \param[in] func_graph The function graph to be managed.
184   /// \param[in] manage If true, the created manager will be set in the graph (default: true).
185   ///
186   /// \return The manager that manages the given function graph.
187   static FuncGraphManagerPtr Manage(const FuncGraphPtr &func_graph, bool manage = true);
188 
189  private:
190   const std::shared_ptr<mindspore::FuncGraphManager> impl_;  // Underlying implementation object (const member).
191 };
192 }  // namespace mindspore::api
193 #endif  // MINDSPORE_CORE_MINDAPI_IR_FUNC_GRAPH_H_
194