1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CORE_MINDAPI_IR_ANF_H_
18 #define MINDSPORE_CORE_MINDAPI_IR_ANF_H_
19
20 #include <vector>
21 #include <string>
22 #include "mindapi/base/base.h"
23 #include "mindapi/ir/common.h"
24 #include "mindapi/ir/abstract.h"
25 #include "mindapi/ir/primitive.h"
26 #include "mindapi/ir/value.h"
27
28 namespace mindspore::api {
29 /// \brief AnfNode is the basic class of the IR graph node.
30 class MIND_API AnfNode : public Base {
31 public:
32 MIND_API_BASE_MEMBER(AnfNode);
33
34 /// \brief Obtain detailed information about scope namespace.
35 ///
36 /// \return Detailed information about scope namespace.
37 std::string fullname_with_scope() const;
38
39 /// \brief Obtain the inferred abstract value of this AnfNode.
40 ///
41 /// \return The inferred abstract value.
42 AbstractBasePtr abstract() const;
43
44 /// \brief Set the abstract value of this AnfNode.
45 ///
46 /// \param[in] abs New abstract value.
47 void set_abstract(const AbstractBasePtr &abs);
48 };
49
50 /// \brief CNode represents a compute node with a set of input nodes.
51 class MIND_API CNode : public AnfNode {
52 public:
53 MIND_API_BASE_MEMBER(CNode);
54
55 /// \brief Get the number of inputs.
56 ///
57 /// \return The number of inputs in this CNode.
58 size_t size() const;
59
60 /// \brief Get the input node of the given index.
61 ///
62 /// \param[in] i The given index.
63 ///
64 /// \return The input node of the given index.
65 AnfNodePtr input(size_t i) const;
66
67 /// \brief Get the input nodes.
68 ///
69 /// \return The input nodes of this CNode.
70 std::vector<AnfNodePtr> inputs() const;
71
72 /// \brief Set the input nodes for this CNode.
73 ///
74 /// \param[in] inputs Input nodes.
75 void set_inputs(const std::vector<AnfNodePtr> &inputs);
76
77 /// \brief Add an input node to this CNode.
78 ///
79 /// \param[in] input the input node to be added.
80 void add_input(const AnfNodePtr &input);
81
82 /// \brief Set fullname_with_scope for this CNode.
83 ///
84 /// \param[in] full_name The fullname_with_scope.
85 void set_fullname_with_scope(const std::string &full_name);
86
87 /// \brief Add a new attribute to this CNode.
88 ///
89 /// \param[in] name The name of the new attribute.
90 /// \param[in] attr The value of the new attribute.
91 void AddAttr(const std::string &name, const ValuePtr &attr);
92
93 /// \brief Erase the attribute with the given name.
94 ///
95 /// \param[in] name The name of attribute.
96 void EraseAttr(const std::string &name);
97
98 /// \brief Get the attribute with the given name.
99 ///
100 /// \param[in] name The name of attribute.
101 /// \return Attribute.
102 ValuePtr GetAttr(const std::string &name) const;
103 };
104
105 using CNodePtr = SharedPtr<CNode>;
106
107 /// \brief Parameter represents the parameter inputs of a function.
108 class MIND_API Parameter : public AnfNode {
109 public:
110 MIND_API_BASE_MEMBER(Parameter);
111
112 /// \brief Get the name of this Parameter.
113 ///
114 /// \return The name.
115 std::string name() const;
116
117 /// \brief Set the name of this Parameter.
118 ///
119 /// \param[in] name The name.
120 void set_name(const std::string &name);
121
122 /// \brief Check if there is a default parameter.
123 ///
124 /// \return True if this Parameter has a default parameter, otherwise false.
125 bool has_default() const;
126
127 /// \brief Set the default parameter.
128 ///
129 /// \param[in] param The default parameter.
130 void set_default_param(const ValuePtr ¶m);
131
132 /// \brief Get the default parameter.
133 ///
134 /// \return The default parameter.
135 ValuePtr default_param() const;
136 };
137
138 using ParameterPtr = SharedPtr<Parameter>;
139
140 /// \brief ValueNode is a graph node that hold a value.
141 class MIND_API ValueNode : public AnfNode {
142 public:
143 MIND_API_BASE_MEMBER(ValueNode);
144
145 /// \brief Create ValueNode with the given value.
146 ///
147 /// \param[in] value The value of this ValueNode.
148 explicit ValueNode(const ValuePtr &value);
149
150 /// \brief Get the value of this ValueNode.
151 ///
152 /// \return The value.
153 ValuePtr value() const;
154 };
155
156 using ValueNodePtr = SharedPtr<ValueNode>;
157
158 // === ANF utility functions === //
159
160 /// \brief Create a ValueNode with the given value.
161 ///
162 /// \param[in] value The given value.
163 ///
164 /// \return The created ValueNode.
165 template <typename T, typename = typename std::enable_if_t<std::is_base_of_v<Value, T>, T>>
NewValueNode(const SharedPtr<T> & value)166 inline ValueNodePtr NewValueNode(const SharedPtr<T> &value) {
167 return MakeShared<ValueNode>(value);
168 }
169
170 /// \brief Create a ValueNode with the given primitive type value.
171 ///
172 /// \param[in] value The given primitive type value.
173 ///
174 /// \return The created ValueNode.
175 template <typename T>
NewValueNode(T value)176 inline ValueNodePtr NewValueNode(T value) {
177 return NewValueNode(MakeValue(value));
178 }
179
180 /// \brief Get the value from a node if it is a ValueNode.
181 ///
182 /// \param[in] node The node which may hold a value.
183 ///
184 /// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set.
GetValueNode(const AnfNodePtr & node)185 inline ValuePtr GetValueNode(const AnfNodePtr &node) {
186 if (node == nullptr) {
187 return nullptr;
188 }
189 auto value_node = node->cast<ValueNodePtr>();
190 if (value_node == nullptr) {
191 return nullptr;
192 }
193 return value_node->value();
194 }
195
196 /// \brief Get the value with the given type from a node if it is a ValueNode.
197 ///
198 /// \param[in] node The node which may hold a value.
199 ///
200 /// \return A pointer to the value, nullptr if the node is not a ValueNode, or value not set, or value type is mismatch.
201 template <typename T, typename = typename std::enable_if_t<
202 is_wrapper_ptr<T>::value && std::is_base_of_v<Value, typename T::element_type>, T>>
GetValueNode(const AnfNodePtr & node)203 inline T GetValueNode(const AnfNodePtr &node) {
204 auto value = GetValueNode(node);
205 if (value == nullptr) {
206 return nullptr;
207 }
208 return value->cast<T>();
209 }
210
211 /// \brief Check whether the given node is a cnode with the given Primitive as the first input.
212 ///
213 /// \param[in] node The given node to be checked.
214 /// \param[in] prim The Primitive value, nullptr means match any Primitive.
215 ///
216 /// \return True if the node is cnode and the first input is the given Primitive, false otherwise.
217 MIND_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &prim = nullptr);
218
219 /// \brief Check whether the given node is a ValueNode with the given Primitive.
220 ///
221 /// \param[in] node The given node to be checked.
222 /// \param[in] prim The Primitive value.
223 ///
224 /// \return True if the given node is a ValueNode with the given Primitive, false otherwise.
225 MIND_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &prim);
226
227 /// \brief Check if a node is a data node.
228 /// Some nodes may be used internally to pass some non-data states, those nodes are not data nodes.
229 ///
230 /// \param[in] node The node to be checked.
231 ///
232 /// \return True if the node is a data node, false otherwise.
233 MIND_API bool IsDataNode(const AnfNodePtr &node);
234 } // namespace mindspore::api
235 #endif // MINDSPORE_CORE_MINDAPI_IR_ANF_H_
236