• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_MINDAPI_IR_ANF_H_
18 #define MINDSPORE_CORE_MINDAPI_IR_ANF_H_
19 
20 #include <vector>
21 #include <string>
22 #include "mindapi/base/base.h"
23 #include "mindapi/ir/common.h"
24 #include "mindapi/ir/abstract.h"
25 #include "mindapi/ir/primitive.h"
26 #include "mindapi/ir/value.h"
27 
28 namespace mindspore::api {
29 /// \brief AnfNode is the basic class of the IR graph node.
30 class MIND_API AnfNode : public Base {
31  public:
32   MIND_API_BASE_MEMBER(AnfNode);
33 
34   /// \brief Obtain detailed information about scope namespace.
35   ///
36   /// \return Detailed information about scope namespace.
37   std::string fullname_with_scope() const;
38 
39   /// \brief Obtain the inferred abstract value of this AnfNode.
40   ///
41   /// \return The inferred abstract value.
42   AbstractBasePtr abstract() const;
43 
44   /// \brief Set the abstract value of this AnfNode.
45   ///
46   /// \param[in] abs New abstract value.
47   void set_abstract(const AbstractBasePtr &abs);
48 };
49 
50 /// \brief CNode represents a compute node with a set of input nodes.
51 class MIND_API CNode : public AnfNode {
52  public:
53   MIND_API_BASE_MEMBER(CNode);
54 
55   /// \brief Get the number of inputs.
56   ///
57   /// \return The number of inputs in this CNode.
58   size_t size() const;
59 
60   /// \brief Get the input node of the given index.
61   ///
62   /// \param[in] i The given index.
63   ///
64   /// \return The input node of the given index.
65   AnfNodePtr input(size_t i) const;
66 
67   /// \brief Get the input nodes.
68   ///
69   /// \return The input nodes of this CNode.
70   std::vector<AnfNodePtr> inputs() const;
71 
72   /// \brief Set the input nodes for this CNode.
73   ///
74   /// \param[in] inputs Input nodes.
75   void set_inputs(const std::vector<AnfNodePtr> &inputs);
76 
77   /// \brief Add an input node to this CNode.
78   ///
79   /// \param[in] input the input node to be added.
80   void add_input(const AnfNodePtr &input);
81 
82   /// \brief Set fullname_with_scope for this CNode.
83   ///
84   /// \param[in] full_name The fullname_with_scope.
85   void set_fullname_with_scope(const std::string &full_name);
86 
87   /// \brief Add a new attribute to this CNode.
88   ///
89   /// \param[in] name The name of the new attribute.
90   /// \param[in] attr The value of the new attribute.
91   void AddAttr(const std::string &name, const ValuePtr &attr);
92 
93   /// \brief Erase the attribute with the given name.
94   ///
95   /// \param[in] name The name of attribute.
96   void EraseAttr(const std::string &name);
97 
98   /// \brief Get the attribute with the given name.
99   ///
100   /// \param[in] name The name of attribute.
101   /// \return Attribute.
102   ValuePtr GetAttr(const std::string &name) const;
103 };
104 
105 using CNodePtr = SharedPtr<CNode>;
106 
107 /// \brief Parameter represents the parameter inputs of a function.
108 class MIND_API Parameter : public AnfNode {
109  public:
110   MIND_API_BASE_MEMBER(Parameter);
111 
112   /// \brief Get the name of this Parameter.
113   ///
114   /// \return The name.
115   std::string name() const;
116 
117   /// \brief Set the name of this Parameter.
118   ///
119   /// \param[in] name The name.
120   void set_name(const std::string &name);
121 
122   /// \brief Check if there is a default parameter.
123   ///
124   /// \return True if this Parameter has a default parameter, otherwise false.
125   bool has_default() const;
126 
127   /// \brief Set the default parameter.
128   ///
129   /// \param[in] param The default parameter.
130   void set_default_param(const ValuePtr &param);
131 
132   /// \brief Get the default parameter.
133   ///
134   /// \return The default parameter.
135   ValuePtr default_param() const;
136 };
137 
138 using ParameterPtr = SharedPtr<Parameter>;
139 
140 /// \brief ValueNode is a graph node that hold a value.
141 class MIND_API ValueNode : public AnfNode {
142  public:
143   MIND_API_BASE_MEMBER(ValueNode);
144 
145   /// \brief Create ValueNode with the given value.
146   ///
147   /// \param[in] value The value of this ValueNode.
148   explicit ValueNode(const ValuePtr &value);
149 
150   /// \brief Get the value of this ValueNode.
151   ///
152   /// \return The value.
153   ValuePtr value() const;
154 };
155 
156 using ValueNodePtr = SharedPtr<ValueNode>;
157 
158 // === ANF utility functions === //
159 
160 /// \brief Create a ValueNode with the given value.
161 ///
162 /// \param[in] value The given value.
163 ///
164 /// \return The created ValueNode.
165 template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Value, T>, T>>
NewValueNode(const SharedPtr<T> & value)166 inline ValueNodePtr NewValueNode(const SharedPtr<T> &value) {
167   return MakeShared<ValueNode>(value);
168 }
169 
170 /// \brief Create a ValueNode with the given primitive type value.
171 ///
172 /// \param[in] value The given primitive type value.
173 ///
174 /// \return The created ValueNode.
175 template <typename T>
NewValueNode(T value)176 inline ValueNodePtr NewValueNode(T value) {
177   return NewValueNode(MakeValue(value));
178 }
179 
180 /// \brief Get the value from a node if it is a ValueNode.
181 ///
182 /// \param[in] node The node which may hold a value.
183 ///
184 /// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set.
GetValueNode(const AnfNodePtr & node)185 inline ValuePtr GetValueNode(const AnfNodePtr &node) {
186   if (node == nullptr) {
187     return nullptr;
188   }
189   auto value_node = node->cast<ValueNodePtr>();
190   if (value_node == nullptr) {
191     return nullptr;
192   }
193   return value_node->value();
194 }
195 
196 /// \brief Get the value with the given type from a node if it is a ValueNode.
197 ///
198 /// \param[in] node The node which may hold a value.
199 ///
200 /// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set, or value type is mismatch.
201 template <typename T, typename = typename std::enable_if_t<
202                         is_wrapper_ptr<T>::value && std::is_base_of_v<Value, typename T::element_type>, T>>
GetValueNode(const AnfNodePtr & node)203 inline T GetValueNode(const AnfNodePtr &node) {
204   auto value = GetValueNode(node);
205   if (value == nullptr) {
206     return nullptr;
207   }
208   return value->cast<T>();
209 }
210 
211 /// \brief Check whether the given node is a cnode with the given Primitive as the first input.
212 ///
213 /// \param[in] node The given node to be checked.
214 /// \param[in] prim The Primitive value, nullptr means match any Primitive.
215 ///
216 /// \return True if the node is cnode and the first input is the given Primitive, false otherwise.
217 MIND_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim = nullptr);
218 
219 /// \brief Check whether the given node is a ValueNode with the given Primitive.
220 ///
221 /// \param[in] node The given node to be checked.
222 /// \param[in] prim The Primitive value.
223 ///
224 /// \return True if the given node is a ValueNode with the given Primitive, false otherwise.
225 MIND_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim);
226 
227 /// \brief Check if a node is a data node.
228 /// Some nodes may be used internally to pass some non-data states, those nodes are not data nodes.
229 ///
230 /// \param[in] node The node to be checked.
231 ///
232 /// \return True if the node is a data node, false otherwise.
233 MIND_API bool IsDataNode(const AnfNodePtr &node);
234 }  // namespace mindspore::api
235 #endif  // MINDSPORE_CORE_MINDAPI_IR_ANF_H_
236