1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_IR_ANF_H_
20 #define MINDSPORE_CORE_IR_ANF_H_
21
#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "base/base.h"
#include "base/user_data.h"
#include "base/effect_info.h"
#include "ir/kernel_info_dev.h"
#include "ir/scope.h"
#include "ir/primal_attr.h"
#include "ir/primal_debug_info.h"
#include "utils/info.h"
#include "utils/ms_utils.h"
39
// The MindSpore ANF IR is defined here,
// following this BNF:
42 // <ANode> ::= Scalar | Named | Tensor | Var |
43 // Prim | MetaFuncGraph | FuncGraph | Type|
44 // Shape | Param
45 // <CNode> ::= (<ANode> ...)
46 // <AnfNode> ::= <CNode> | <ANode>
47 // ANode: Atomic Node
48 // CNode: Complex Node
49 namespace mindspore {
// Forward declarations and shared-pointer aliases so the node classes below can refer to
// each other (and to types fully defined elsewhere) before their complete definitions.
namespace abstract {
class BaseShape;
class AbstractBase;
} // namespace abstract
using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;

// Value is defined later in this header.
class Value;
using ValuePtr = std::shared_ptr<Value>;
using ValuePtrList = std::vector<ValuePtr>;

// ValueNode is defined later in this header.
class ValueNode;
using ValueNodePtr = std::shared_ptr<ValueNode>;

// CNode is defined later in this header.
class CNode;
using CNodePtr = std::shared_ptr<CNode>;
using CNodePtrList = std::vector<CNodePtr>;
using CNodeWeakPtr = std::weak_ptr<CNode>;

// NOTE(review): FuncGraphPtr and OrderedSet are not declared in this header; presumably they
// come from base/base.h — verify.
class FuncGraph;
using FuncGraphSet = OrderedSet<FuncGraphPtr>;
using FuncGraphVector = std::vector<FuncGraphPtr>;

class Primitive;
using PrimitivePtr = std::shared_ptr<Primitive>;

class BaseRef;

class Var;
using VarPtr = std::shared_ptr<Var>;

class AnfIrVisitor;

class ParamInfo;
using ParamInfoPtr = std::shared_ptr<ParamInfo>;
86
87 // AnfNode is the basic class of the IR definition derived from Base.
88 // Only two types of nodes are derived: CNode and ANode.
89 // Methods:
90 // func_graph: return FuncGraph that this AnfNode belongs to.
91 // scope: return the scope namespace of this AnfNode. Set it using set_scope.
92 // abstract: return the cached inferred abstract value. It contains type, shape
93 // value. Set New cache using set_abstract.
94 // intermediate_abstract: return the cached inferring abstract value.
95 // Type/Shape: return the related info of this AnfNode. When this AnfNode is an
96 // input of other CNodes, you can get the related info by this method.
97 // debug_info: return the information retrieved from parser. Set it using set_debug_info.
98 // fullname_with_scope: return the detailed debug info.
99 class MS_CORE_API AnfNode : public Base {
100 public:
AnfNode(const FuncGraphPtr & func_graph)101 explicit AnfNode(const FuncGraphPtr &func_graph)
102 : func_graph_(FuncGraphWeakPtr(func_graph)),
103 abstract_(nullptr),
104 intermediate_abstract_(nullptr),
105 debug_info_(std::make_shared<NodeDebugInfo>()),
106 fullname_with_scope_(""),
107 hash_(std::hash<const AnfNode *>()),
108 kernel_info_(nullptr),
109 stage_(-1),
110 need_grad_(false),
111 interpret_(false),
112 interpreted_node_(nullptr) {
113 scope_ = ScopeManager::GetInstance().GetCurrentScope();
114 }
115
116 ~AnfNode() override = default;
117 MS_DECLARE_PARENT(AnfNode, Base);
118
accept(AnfIrVisitor *)119 virtual void accept(AnfIrVisitor *) {}
func_graph()120 FuncGraphPtr func_graph() const { return func_graph_.lock(); }
121
set_func_graph(const FuncGraphPtr & func_graph)122 virtual void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
123
scope()124 ScopePtr scope() { return scope_; }
set_scope(const ScopePtr & scope)125 void set_scope(const ScopePtr &scope) { scope_ = scope; }
126
kernel_info()127 const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); }
kernel_info()128 KernelInfoDevice *kernel_info() { return kernel_info_.get(); }
kernel_info_ptr()129 const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; }
set_kernel_info(const KernelInfoDevicePtr & kernel_info)130 void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; }
131
abstract()132 const AbstractBasePtr &abstract() const { return abstract_; }
set_abstract(const AbstractBasePtr & abs)133 void set_abstract(const AbstractBasePtr &abs) { abstract_ = abs; }
134
intermediate_abstract()135 AbstractBasePtr intermediate_abstract() { return intermediate_abstract_; }
set_intermediate_abstract(const AbstractBasePtr & abs)136 void set_intermediate_abstract(const AbstractBasePtr &abs) { intermediate_abstract_ = abs; }
137
debug_info()138 NodeDebugInfoPtr debug_info() {
139 MS_EXCEPTION_IF_NULL(debug_info_);
140 if (debug_info_->get_node() == nullptr) {
141 debug_info_->set_node(shared_from_base<AnfNode>());
142 }
143 return debug_info_;
144 }
set_debug_info(const NodeDebugInfoPtr & debug_info)145 void set_debug_info(const NodeDebugInfoPtr &debug_info) {
146 MS_EXCEPTION_IF_NULL(debug_info);
147 debug_info_ = debug_info;
148 if (debug_info_->get_node() == nullptr) {
149 debug_info_->set_node(shared_from_base<AnfNode>());
150 }
151 }
152
153 TypePtr Type() const;
154 BaseShapePtr Shape() const;
155
hash()156 std::size_t hash() const override { return this->hash_(this); }
fullname_with_scope()157 virtual std::string fullname_with_scope() { return ""; }
UniqueName()158 std::string UniqueName() { return fullname_with_scope() + "_" + UniqueId(); }
159
160 virtual std::string DebugString(int recursive_level = 1) const { return ToString(); }
DebugString(bool recursive)161 virtual std::string DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
162 std::string ToString() const override;
dump()163 void dump() const override { std::cout << DebugString() << std::endl; }
UniqueId()164 std::string UniqueId() { return std::to_string(debug_info()->unique_id()); }
UniqueIdThroughCopy()165 std::string UniqueIdThroughCopy() { return std::to_string(debug_info()->unique_id_through_copy()); }
166 virtual bool operator==(const AnfNode &other) const { return &other == this; }
167 friend std::ostream &operator<<(std::ostream &os, const AnfNode &node) {
168 os << node.ToString();
169 return os;
170 }
171 size_t seen_{0};
172 size_t extra_seen_{0};
173
174 template <typename T>
set_user_data(const std::string & key,const std::shared_ptr<T> & value)175 void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
176 user_data_.set<T>(key, value);
177 }
178
179 template <typename T>
set_user_data(const std::shared_ptr<T> & value)180 void set_user_data(const std::shared_ptr<T> &value) {
181 user_data_.set<T>(T::key, value);
182 }
183
184 template <typename T>
user_data(const std::string & key)185 std::shared_ptr<T> user_data(const std::string &key) const {
186 return user_data_.get<T>(key);
187 }
188
189 template <typename T>
user_data()190 std::shared_ptr<T> user_data() const {
191 return user_data_.get<T>(T::key);
192 }
193
has_user_data(const std::string & key)194 bool has_user_data(const std::string &key) const { return user_data_.has(key); }
195
196 template <typename T>
has_user_data()197 bool has_user_data() const {
198 return user_data_.has(T::key);
199 }
200
CloneUserData(const AnfNodePtr & node)201 void CloneUserData(const AnfNodePtr &node) { user_data_ = node->user_data_; }
202
stage()203 int64_t stage() { return stage_; }
set_stage(const int & stage)204 void set_stage(const int &stage) { stage_ = stage; }
205
grad()206 bool grad() { return need_grad_; }
set_grad(const bool & need_grad)207 void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
208
interpret()209 bool interpret() { return interpret_; }
set_interpret(const bool & interpret)210 void set_interpret(const bool &interpret) { interpret_ = interpret; }
211
interpreted_node()212 AnfNodePtr interpreted_node() { return interpreted_node_; }
set_interpreted_node(const AnfNodePtr & node)213 void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; }
214
215 protected:
216 // Hold a weak ref to Graph as Graph also hold ref to AnfNode.
217 // Otherwise, func_graph_ and AnfNode will make a reference cycle.
218 FuncGraphWeakPtr func_graph_;
219 AbstractBasePtr abstract_;
220 AbstractBasePtr intermediate_abstract_;
221 NodeDebugInfoPtr debug_info_;
222 std::string fullname_with_scope_;
223
224 private:
225 std::hash<const AnfNode *> hash_;
226 ScopePtr scope_;
227 KernelInfoDevicePtr kernel_info_;
228 UserData user_data_;
229 int64_t stage_;
230 bool need_grad_;
231 bool interpret_;
232 AnfNodePtr interpreted_node_;
233 };
234
235 // CNode represents the complex node with a set of arguments.
236 // Fields:
237 // inputs_: represents all of the inputs for this CNode.
238 // Using input(i) to get the index i input.
239 // Using inputs() to get all the inputs as a vector.
240 // Using add_input(input) to append a new input for a CNode.
241 // Using set_input(i, input) to change some input of these inputs.
242 // Using set_inputs(inputs) to refresh all of the inputs of a CNode.
243 // func_graph_as_var_: used in opt pattern matching to match a real FuncGraph.
244 // stop_gradient_: a flag used to stop gradient.
245 // Using stop_gradient() to get this flag, mainly used in ad.
246 // Using set_stop_gradient() to set this flag.
247 class MS_CORE_API CNode : public AnfNode, public EffectInfoHolder {
248 public:
249 CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph);
CNode(const std::vector<AnfNodePtr> & inputs,const VarPtr & func_graph_as_var)250 CNode(const std::vector<AnfNodePtr> &inputs, const VarPtr &func_graph_as_var)
251 : AnfNode(nullptr),
252 inputs_(inputs),
253 func_graph_as_var_(func_graph_as_var),
254 stop_gradient_(false),
255 input_tensor_num_(-1) {
256 primal_attrs_ = PrimalAttrManager::GetInstance().GetCurrentPrimalAttr();
257 primal_debug_infos_ = PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo();
258 }
259
260 ~CNode() override = default;
261 MS_DECLARE_PARENT(CNode, AnfNode);
262
263 void accept(AnfIrVisitor *v) override;
264 // check whether this cnode has some primitive value as the first input.
265 bool IsApply(const PrimitivePtr &) const;
266
size()267 const size_t size() const { return inputs_.size(); }
268 const AnfNodePtr &input(size_t i) const;
inputs()269 const std::vector<AnfNodePtr> &inputs() const { return inputs_; }
270 void add_input(const AnfNodePtr &input);
271 void set_input(size_t i, const AnfNodePtr &input);
272 void set_inputs(const std::vector<AnfNodePtr> &inputs);
273
add_input_value(const ValuePtr & input_value,const std::string & id)274 void add_input_value(const ValuePtr &input_value, const std::string &id) {
275 inputs_value_.push_back(std::make_pair(input_value, id));
276 }
clear_inputs_value()277 void clear_inputs_value() { inputs_value_.clear(); }
set_inputs_value(const std::vector<std::pair<ValuePtr,std::string>> & values)278 void set_inputs_value(const std::vector<std::pair<ValuePtr, std::string>> &values) { inputs_value_ = values; }
inputs_value()279 const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
280
set_forward(const ValueNodePtr & forward,const std::string & id)281 void set_forward(const ValueNodePtr &forward, const std::string &id) { output_value_ = std::make_pair(forward, id); }
forward()282 const std::pair<ValueNodePtr, std::string> &forward() const { return output_value_; }
283
stop_gradient()284 bool stop_gradient() const { return stop_gradient_; }
set_stop_gradient(bool stop_gradient)285 void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
286
287 std::string fullname_with_scope() override;
set_fullname_with_scope(const std::string full_name)288 void set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; }
289 std::string DebugString(int recursive_level = 1) const override;
DebugString(bool recursive)290 std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
291
set_in_forward_flag(bool flag)292 void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
in_forward_flag()293 bool in_forward_flag() const { return in_forward_flag_; }
294
set_load_flag(bool is_load)295 void set_load_flag(bool is_load) { is_load_ = is_load; }
get_load_flag()296 bool get_load_flag() { return is_load_; }
297
func_graph_as_var()298 VarPtr func_graph_as_var() const { return func_graph_as_var_; }
299
attrs()300 const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
set_attrs(const std::unordered_map<std::string,ValuePtr> & attrs)301 void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
302 attrs_.insert(attrs.cbegin(), attrs.cend());
303 }
304
AddAttr(const std::string & name,const ValuePtr & attr)305 void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
EraseAttr(const std::string & name)306 void EraseAttr(const std::string &name) { (void)attrs_.erase(name); }
GetAttr(const std::string & name)307 ValuePtr GetAttr(const std::string &name) const {
308 auto iter = attrs_.find(name);
309 return iter == attrs_.cend() ? nullptr : iter->second;
310 }
HasAttr(const std::string & name)311 bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); }
input_tensor_num()312 ssize_t input_tensor_num() const { return input_tensor_num_; }
313
primal_attrs()314 const std::unordered_map<std::string, ValuePtr> &primal_attrs() const { return primal_attrs_; }
set_primal_attrs(const std::unordered_map<std::string,ValuePtr> & attrs)315 void set_primal_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
316 primal_attrs_.insert(attrs.cbegin(), attrs.cend());
317 }
AddPrimalAttr(const std::string & name,const ValuePtr & attr)318 void AddPrimalAttr(const std::string &name, const ValuePtr &attr) { primal_attrs_[name] = attr; }
ErasePrimalAttr(const std::string & name)319 void ErasePrimalAttr(const std::string &name) { (void)primal_attrs_.erase(name); }
GetPrimalAttr(const std::string & name)320 ValuePtr GetPrimalAttr(const std::string &name) const {
321 auto iter = primal_attrs_.find(name);
322 return iter == primal_attrs_.cend() ? nullptr : iter->second;
323 }
HasPrimalAttr(const std::string & name)324 bool HasPrimalAttr(const std::string &name) const { return primal_attrs_.find(name) != attrs_.cend(); }
325
primal_debug_infos()326 std::vector<NodeDebugInfoPtr> primal_debug_infos() { return primal_debug_infos_; }
327
set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> & debug_infos)328 void set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> &debug_infos) {
329 primal_debug_infos_.insert(primal_debug_infos_.end(), debug_infos.begin(), debug_infos.end());
330 }
331
AddPrimalDebugInfo(const NodeDebugInfoPtr debug_info)332 void AddPrimalDebugInfo(const NodeDebugInfoPtr debug_info) {
333 if (std::find(primal_debug_infos_.begin(), primal_debug_infos_.end(), debug_info) != primal_debug_infos_.end()) {
334 MS_LOG(EXCEPTION) << "Debug_info already in primal_debug_infos_ vector";
335 }
336 primal_debug_infos_.push_back(debug_info);
337 }
338
CloneCNodeInfo(const CNodePtr & node)339 void CloneCNodeInfo(const CNodePtr &node) {
340 MS_EXCEPTION_IF_NULL(node);
341 set_abstract(node->abstract());
342 set_forward(node->forward().first, node->forward().second);
343 set_inputs_value(node->inputs_value());
344 set_attrs(node->attrs());
345 set_primal_attrs(node->primal_attrs());
346 set_load_flag(node->get_load_flag());
347 CloneUserData(node);
348 set_kernel_info(node->kernel_info_ptr());
349 set_primal_debug_infos(node->primal_debug_infos());
350 }
351
set_input_tensor_num(ssize_t input_tensor_num)352 void set_input_tensor_num(ssize_t input_tensor_num) { input_tensor_num_ = input_tensor_num; }
353
354 // Is effect have been handled.
IsEffectHandled()355 bool IsEffectHandled() const { return effect_handled_; }
356
357 // Set effect handled or not.
SetEffectHandled(bool handled)358 void SetEffectHandled(bool handled) { effect_handled_ = handled; }
359
360 private:
361 std::vector<AnfNodePtr> inputs_;
362 VarPtr func_graph_as_var_;
363 bool stop_gradient_;
364 bool in_forward_flag_ = false;
365 bool effect_handled_ = false;
366 bool is_load_ = false;
367 // inputs_value_ store cnode input value and id in pynative mode
368 // output_value_ store cnode value and id in pynative mode
369 std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
370 std::pair<ValueNodePtr, std::string> output_value_;
371 std::unordered_map<std::string, ValuePtr> attrs_;
372 std::unordered_map<std::string, ValuePtr> primal_attrs_;
373 std::vector<NodeDebugInfoPtr> primal_debug_infos_;
374 ssize_t input_tensor_num_ = -1;
375 };
376
377 // ANode represents the atomic node. It's derived Parameter and ValueNode.
378 class MS_CORE_API ANode : public AnfNode {
379 public:
ANode()380 ANode() : AnfNode(nullptr) {}
ANode(const FuncGraphPtr & func_graph)381 explicit ANode(const FuncGraphPtr &func_graph) : AnfNode(func_graph) {}
382 virtual ~ANode() = default;
383
384 MS_DECLARE_PARENT(ANode, AnfNode);
385 };
386
387 // Parameter represents the parameter inputs of a function. They have no value.
388 // Attributes:
389 // default_param_value_: used to hold the inputting tensor of the model.
390 class MS_CORE_API Parameter : public ANode {
391 public:
Parameter(const FuncGraphPtr & func_graph)392 explicit Parameter(const FuncGraphPtr &func_graph)
393 : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {}
394 ~Parameter() override = default;
395 MS_DECLARE_PARENT(Parameter, ANode);
396
397 void accept(AnfIrVisitor *v) override;
398 std::string DebugString(int recursive_level = 1) const override;
name()399 std::string name() const { return name_; }
set_name(const std::string & name)400 void set_name(const std::string &name) { name_ = name; }
fullname_with_scope()401 std::string fullname_with_scope() override { return name(); }
402
has_default()403 bool has_default() const { return has_default_; }
set_default_param(ValuePtr param)404 void set_default_param(ValuePtr param) {
405 default_param_ = param;
406 has_default_ = true;
407 }
default_param()408 ValuePtr default_param() const { return default_param_; }
409 ParamInfoPtr param_info() const;
410
IncreaseUsedGraphCount()411 void IncreaseUsedGraphCount() { used_graph_count_++; }
DecreaseUsedGraphCount()412 void DecreaseUsedGraphCount() { used_graph_count_--; }
used_graph_count()413 int used_graph_count() const { return used_graph_count_; }
414
415 bool operator==(const AnfNode &other) const override {
416 if (!other.isa<Parameter>()) {
417 return false;
418 }
419 auto p = static_cast<const Parameter &>(other);
420 if (name_.length() > 0 && p.name_.length() > 0) {
421 return p.name_ == name_;
422 }
423 return shared_from_this() == other.shared_from_this();
424 }
425
SetNotUsedByRealKernelInGraph(uint32_t graph_id)426 void SetNotUsedByRealKernelInGraph(uint32_t graph_id) { (void)not_used_in_graphs_.insert(graph_id); }
427
IsUsedByRealKernelInGraph(uint32_t graph_id)428 bool IsUsedByRealKernelInGraph(uint32_t graph_id) const {
429 if (not_used_in_graphs_.find(graph_id) != not_used_in_graphs_.end()) {
430 return false;
431 }
432 return true;
433 }
434
set_has_dynamic_shape(bool flag)435 void set_has_dynamic_shape(bool flag) { has_dynamic_shape_ = flag; }
has_dynamic_shape()436 bool has_dynamic_shape() const { return has_dynamic_shape_; }
437
set_fracz_group(int64_t fracz_group)438 void set_fracz_group(int64_t fracz_group) { fracz_group_ = fracz_group; }
fracz_group()439 int64_t fracz_group() { return fracz_group_; }
440
441 private:
442 std::string name_;
443 bool has_default_;
444 std::set<uint32_t> not_used_in_graphs_;
445 bool has_dynamic_shape_ = false;
446 ValuePtr default_param_;
447 // The count of graphs using the parameter.
448 int used_graph_count_;
449 // groups attr in FracZ format
450 int64_t fracz_group_ = 1;
451 };
452 using ParameterPtr = std::shared_ptr<Parameter>;
453
454 // Value is used to represent the atomic expression mentioned in BNF.
455 // It mainly be stored in ValueNode. Value and ValueNode is related definition.
456 class MS_CORE_API Value : public Base {
457 public:
458 Value() = default;
Value(const TypePtr t)459 explicit Value(const TypePtr t) : type_(t) {}
Value(const Value & other)460 Value(const Value &other) : Base(other) { this->type_ = other.type_; }
461 ~Value() override = default;
MS_DECLARE_PARENT(Value,Base)462 MS_DECLARE_PARENT(Value, Base)
463
464 TypePtr type() const { return type_; }
ToAbstract()465 virtual abstract::AbstractBasePtr ToAbstract() {
466 MS_LOG(EXCEPTION) << "ToAbstract error";
467 abstract::AbstractBasePtr result;
468 return result;
469 }
470
471 virtual bool operator==(const Value &rhs) const = 0;
472 virtual Value &operator=(const Value &other) {
473 if (&other == this) {
474 return *this;
475 }
476 this->type_ = other.type_;
477 return *this;
478 }
479
480 protected:
481 TypePtr type_{nullptr};
482 };
483
484 // ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
485 // does not belong to any particular function graph.
486 class MS_CORE_API ValueNode : public ANode {
487 public:
ValueNode(const ValuePtr & value)488 explicit ValueNode(const ValuePtr &value) : value_(value) {}
489 ~ValueNode() override = default;
490 MS_DECLARE_PARENT(ValueNode, ANode);
491
set_func_graph(const FuncGraphPtr & func_graph)492 void set_func_graph(const FuncGraphPtr &func_graph) override {
493 MS_EXCEPTION(ValueError) << "ValueNode should not set its func_graph.";
494 }
495
496 void accept(AnfIrVisitor *v) override;
set_value(const ValuePtr & value)497 void set_value(const ValuePtr &value) { value_ = value; }
value()498 const ValuePtr &value() const { return value_; }
499 std::string fullname_with_scope() override;
500
set_has_new_value(bool flag)501 void set_has_new_value(bool flag) { has_new_value_ = flag; }
has_new_value()502 bool has_new_value() const { return has_new_value_; }
503
used_graph_count()504 size_t used_graph_count() const { return used_graph_count_; }
set_used_graph_count(size_t used_graph_count)505 void set_used_graph_count(size_t used_graph_count) { used_graph_count_ = used_graph_count; }
506
507 std::string ToString() const override;
508 std::string DebugString(int recursive_level = 1) const override;
DebugString(bool recursive)509 std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
510
511 bool operator==(const AnfNode &other) const override {
512 if (!other.isa<ValueNode>()) {
513 return false;
514 }
515 auto v = static_cast<const ValueNode &>(other);
516 return *v.value() == *value();
517 }
518 friend std::ostream &operator<<(std::ostream &os, const ValueNodePtr &node) {
519 MS_EXCEPTION_IF_NULL(node);
520 os << node->ToString();
521 return os;
522 }
523
524 private:
525 ValuePtr value_;
526 bool has_new_value_ = false;
527 size_t used_graph_count_{0};
528 };
529
// Maps a C++ scalar/prototype type to its immediate-Value pointer type. The primary
// template is deliberately empty; IMM_TRAITS supplies each specialization.
template <typename T>
struct ImmTraits {};

// IMM_TRAITS(ImmPtrType, CppType) specializes ImmTraits<CppType>::type = ImmPtrType.
// The expansion already ends with ';', so uses need no trailing semicolon.
#define IMM_TRAITS(typeimm, prototype)                 \
  template <>                                          \
  struct ImmTraits<prototype> { using type = typeimm; };
538
MakeValue(const ValuePtr & value)539 inline ValuePtr MakeValue(const ValuePtr &value) { return value; }
540
541 template <typename S, typename U = typename ImmTraits<S>::type::element_type>
MakeValue(S v)542 inline ValuePtr MakeValue(S v) {
543 return std::make_shared<U>(v);
544 }
545
546 template <typename S, typename U = typename ImmTraits<S>::type>
GetValue(const ValuePtr & value)547 static S GetValue(const ValuePtr &value) {
548 MS_EXCEPTION_IF_NULL(value);
549 U imm = value->cast<U>();
550 if (imm == nullptr) {
551 MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name();
552 }
553 return imm->value();
554 }
555
556 template <typename S,
557 typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
558 S>::type * = nullptr>
GetValue(const ValuePtr & value)559 static S GetValue(const ValuePtr &value) {
560 MS_EXCEPTION_IF_NULL(value);
561 S v = value->cast<S>();
562 if (v == nullptr) {
563 MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name();
564 }
565 return v;
566 }
567
// Returns a function name for the given cnode (implementation not visible in this header).
std::string GetCNodeFuncName(CNodePtr cnode);

// used to get FuncGraphPtr from a cnode first input
FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node);

// used to check whether an AnfNode is a cnode with a kind of Primitive as first input
// NOTE(review): presumably a null `value` (the default) matches any primitive — verify in the .cc.
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);

// used to get PrimitivePtr from a cnode first input
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);

// used to check whether an AnfNode is a valuenode having some Primitive value
MS_CORE_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value);

// Check whether two primitives are same.
bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2);

// Get number of AbstractMonad
size_t GetAbstractMonadNum(const AbstractBasePtrList &args);

// Check whether the given node has monad abstract.
bool HasAbstractMonad(const AnfNodePtr &node);

// Check whether the given node has U monad abstract.
bool HasAbstractUMonad(const AnfNodePtr &node);

// Check whether the given node has IO monad abstract.
bool HasAbstractIOMonad(const AnfNodePtr &node);

// Gets primitive attribute value as a bool flag.
bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr);

// Gets effect info from a primitive by its attributes.
EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim);

// Monad state of a node: `u` and `io` hold the node's monad inputs (presumably the memory/U
// monad and the IO monad respectively — verify in GetMonadState). Both default to nullptr.
struct MonadState {
  AnfNodePtr u{nullptr};
  AnfNodePtr io{nullptr};
};

// Get Memory/IO monad state from node.
MonadState GetMonadState(const AnfNodePtr &node, const AnfNodePtr &skip_input = nullptr);

// Check if two state is equivalent.
bool IsStateEquivalent(const MonadState &state1, const MonadState &state2);

// Check if monad state is strict equivalent for the connected two nodes.
bool IsStateStrictEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner);

// Check if monad state is equivalent for the connected two nodes, not strict but more faster.
bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner);
619
620 // used to check whether a ValueNode has some kind of value
621 template <typename T>
IsValueNode(const AnfNodePtr & node)622 static bool IsValueNode(const AnfNodePtr &node) {
623 MS_EXCEPTION_IF_NULL(node);
624 auto anode = node->cast<ValueNodePtr>();
625 if (anode != nullptr) {
626 auto value = anode->value();
627 if (value == nullptr) {
628 MS_LOG(EXCEPTION) << "Const value is nullptr.";
629 }
630 return value->isa<T>();
631 }
632 return false;
633 }
634
GetValueNode(const AnfNodePtr & node)635 inline ValuePtr GetValueNode(const AnfNodePtr &node) {
636 MS_EXCEPTION_IF_NULL(node);
637 if (!node->isa<ValueNode>()) {
638 return nullptr;
639 }
640 return node->cast<ValueNodePtr>()->value();
641 }
642
643 template <typename S,
644 typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
645 S>::type * = nullptr>
GetValueNode(const AnfNodePtr & node)646 inline S GetValueNode(const AnfNodePtr &node) {
647 auto value = GetValueNode(node);
648 if (value == nullptr) {
649 return nullptr;
650 }
651 auto s = value->cast<S>();
652 return s;
653 }
654
// Returns a new "seen" generation number (implementation not visible here; presumably compared
// against AnfNode::seen_ / extra_seen_ during graph traversals — verify).
size_t NewSeenGeneration();

// Node id helpers: get_id returns an id string for a node; reset_id resets the generator.
namespace id_generator {
std::string get_id(const AnfNodePtr &node);
void reset_id();
} // namespace id_generator
// Map from node to a size_t tag, and a graph paired with such a map.
using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
// Target queries for nodes (semantics defined in the implementation file).
std::string GetCNodeTarget(const AnfNodePtr &node);
std::string GetOriginNodeTarget(const AnfNodePtr &node);
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
666 struct GraphSegment {
GraphSegmentGraphSegment667 GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
AddPreSegmentGraphSegment668 void AddPreSegment(const std::shared_ptr<GraphSegment> &segment) { (void)pre_segments_.insert(segment); }
669 std::vector<AnfNodePtr> nodes_;
670 std::set<std::shared_ptr<GraphSegment>> pre_segments_;
671 bool is_cut_{false};
672 uint32_t graph_id_{0};
673 };
674 using GraphSegmentPtr = std::shared_ptr<GraphSegment>;
675 } // namespace mindspore
676
677 #endif // MINDSPORE_CORE_IR_ANF_H_
678