• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_IR_ANF_H_
20 #define MINDSPORE_CORE_IR_ANF_H_
21 
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "base/base.h"
#include "base/user_data.h"
#include "base/effect_info.h"
#include "ir/kernel_info_dev.h"
#include "ir/scope.h"
#include "ir/primal_attr.h"
#include "ir/primal_debug_info.h"
#include "utils/info.h"
#include "utils/ms_utils.h"
39 
// The MindSpore ANF IR is defined here, following this BNF:
// <ANode> ::= Scalar | Named | Tensor | Var |
//             Prim   | MetaFuncGraph | FuncGraph | Type |
//             Shape  | Param
// <CNode> ::= (<ANode> ...)
// <AnfNode> ::= <CNode> | <ANode>
// ANode: Atomic Node
// CNode: Complex Node
49 namespace mindspore {
50 namespace abstract {
51 class BaseShape;
52 class AbstractBase;
53 }  // namespace abstract
54 using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
55 using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
56 using AbstractBasePtrList = std::vector<AbstractBasePtr>;
57 
58 class Value;
59 using ValuePtr = std::shared_ptr<Value>;
60 using ValuePtrList = std::vector<ValuePtr>;
61 
62 class ValueNode;
63 using ValueNodePtr = std::shared_ptr<ValueNode>;
64 
65 class CNode;
66 using CNodePtr = std::shared_ptr<CNode>;
67 using CNodePtrList = std::vector<CNodePtr>;
68 using CNodeWeakPtr = std::weak_ptr<CNode>;
69 
70 class FuncGraph;
71 using FuncGraphSet = OrderedSet<FuncGraphPtr>;
72 using FuncGraphVector = std::vector<FuncGraphPtr>;
73 
74 class Primitive;
75 using PrimitivePtr = std::shared_ptr<Primitive>;
76 
77 class BaseRef;
78 
79 class Var;
80 using VarPtr = std::shared_ptr<Var>;
81 
82 class AnfIrVisitor;
83 
84 class ParamInfo;
85 using ParamInfoPtr = std::shared_ptr<ParamInfo>;
86 
87 // AnfNode is the basic class of the IR definition derived from Base.
88 // Only two types of nodes are derived: CNode and ANode.
89 // Methods:
90 // func_graph: return FuncGraph that this AnfNode belongs to.
91 // scope: return the scope namespace of this AnfNode. Set it using set_scope.
92 // abstract: return the cached inferred abstract value. It contains type, shape
93 // value. Set New cache using set_abstract.
94 // intermediate_abstract: return the cached inferring abstract value.
95 // Type/Shape: return the related info of this AnfNode. When this AnfNode is an
96 // input of other CNodes, you can get the related info by this method.
97 // debug_info: return the information retrieved from parser. Set it using set_debug_info.
98 // fullname_with_scope: return the detailed debug info.
99 class MS_CORE_API AnfNode : public Base {
100  public:
AnfNode(const FuncGraphPtr & func_graph)101   explicit AnfNode(const FuncGraphPtr &func_graph)
102       : func_graph_(FuncGraphWeakPtr(func_graph)),
103         abstract_(nullptr),
104         intermediate_abstract_(nullptr),
105         debug_info_(std::make_shared<NodeDebugInfo>()),
106         fullname_with_scope_(""),
107         hash_(std::hash<const AnfNode *>()),
108         kernel_info_(nullptr),
109         stage_(-1),
110         need_grad_(false),
111         interpret_(false),
112         interpreted_node_(nullptr) {
113     scope_ = ScopeManager::GetInstance().GetCurrentScope();
114   }
115 
116   ~AnfNode() override = default;
117   MS_DECLARE_PARENT(AnfNode, Base);
118 
accept(AnfIrVisitor *)119   virtual void accept(AnfIrVisitor *) {}
func_graph()120   FuncGraphPtr func_graph() const { return func_graph_.lock(); }
121 
set_func_graph(const FuncGraphPtr & func_graph)122   virtual void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
123 
scope()124   ScopePtr scope() { return scope_; }
set_scope(const ScopePtr & scope)125   void set_scope(const ScopePtr &scope) { scope_ = scope; }
126 
kernel_info()127   const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); }
kernel_info()128   KernelInfoDevice *kernel_info() { return kernel_info_.get(); }
kernel_info_ptr()129   const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; }
set_kernel_info(const KernelInfoDevicePtr & kernel_info)130   void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; }
131 
abstract()132   const AbstractBasePtr &abstract() const { return abstract_; }
set_abstract(const AbstractBasePtr & abs)133   void set_abstract(const AbstractBasePtr &abs) { abstract_ = abs; }
134 
intermediate_abstract()135   AbstractBasePtr intermediate_abstract() { return intermediate_abstract_; }
set_intermediate_abstract(const AbstractBasePtr & abs)136   void set_intermediate_abstract(const AbstractBasePtr &abs) { intermediate_abstract_ = abs; }
137 
debug_info()138   NodeDebugInfoPtr debug_info() {
139     MS_EXCEPTION_IF_NULL(debug_info_);
140     if (debug_info_->get_node() == nullptr) {
141       debug_info_->set_node(shared_from_base<AnfNode>());
142     }
143     return debug_info_;
144   }
set_debug_info(const NodeDebugInfoPtr & debug_info)145   void set_debug_info(const NodeDebugInfoPtr &debug_info) {
146     MS_EXCEPTION_IF_NULL(debug_info);
147     debug_info_ = debug_info;
148     if (debug_info_->get_node() == nullptr) {
149       debug_info_->set_node(shared_from_base<AnfNode>());
150     }
151   }
152 
153   TypePtr Type() const;
154   BaseShapePtr Shape() const;
155 
hash()156   std::size_t hash() const override { return this->hash_(this); }
fullname_with_scope()157   virtual std::string fullname_with_scope() { return ""; }
UniqueName()158   std::string UniqueName() { return fullname_with_scope() + "_" + UniqueId(); }
159 
160   virtual std::string DebugString(int recursive_level = 1) const { return ToString(); }
DebugString(bool recursive)161   virtual std::string DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
162   std::string ToString() const override;
dump()163   void dump() const override { std::cout << DebugString() << std::endl; }
UniqueId()164   std::string UniqueId() { return std::to_string(debug_info()->unique_id()); }
UniqueIdThroughCopy()165   std::string UniqueIdThroughCopy() { return std::to_string(debug_info()->unique_id_through_copy()); }
166   virtual bool operator==(const AnfNode &other) const { return &other == this; }
167   friend std::ostream &operator<<(std::ostream &os, const AnfNode &node) {
168     os << node.ToString();
169     return os;
170   }
171   size_t seen_{0};
172   size_t extra_seen_{0};
173 
174   template <typename T>
set_user_data(const std::string & key,const std::shared_ptr<T> & value)175   void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
176     user_data_.set<T>(key, value);
177   }
178 
179   template <typename T>
set_user_data(const std::shared_ptr<T> & value)180   void set_user_data(const std::shared_ptr<T> &value) {
181     user_data_.set<T>(T::key, value);
182   }
183 
184   template <typename T>
user_data(const std::string & key)185   std::shared_ptr<T> user_data(const std::string &key) const {
186     return user_data_.get<T>(key);
187   }
188 
189   template <typename T>
user_data()190   std::shared_ptr<T> user_data() const {
191     return user_data_.get<T>(T::key);
192   }
193 
has_user_data(const std::string & key)194   bool has_user_data(const std::string &key) const { return user_data_.has(key); }
195 
196   template <typename T>
has_user_data()197   bool has_user_data() const {
198     return user_data_.has(T::key);
199   }
200 
CloneUserData(const AnfNodePtr & node)201   void CloneUserData(const AnfNodePtr &node) { user_data_ = node->user_data_; }
202 
stage()203   int64_t stage() { return stage_; }
set_stage(const int & stage)204   void set_stage(const int &stage) { stage_ = stage; }
205 
grad()206   bool grad() { return need_grad_; }
set_grad(const bool & need_grad)207   void set_grad(const bool &need_grad) { need_grad_ = need_grad; }
208 
interpret()209   bool interpret() { return interpret_; }
set_interpret(const bool & interpret)210   void set_interpret(const bool &interpret) { interpret_ = interpret; }
211 
interpreted_node()212   AnfNodePtr interpreted_node() { return interpreted_node_; }
set_interpreted_node(const AnfNodePtr & node)213   void set_interpreted_node(const AnfNodePtr &node) { interpreted_node_ = node; }
214 
215  protected:
216   // Hold a weak ref to Graph as Graph also hold ref to AnfNode.
217   // Otherwise, func_graph_ and AnfNode will make a reference cycle.
218   FuncGraphWeakPtr func_graph_;
219   AbstractBasePtr abstract_;
220   AbstractBasePtr intermediate_abstract_;
221   NodeDebugInfoPtr debug_info_;
222   std::string fullname_with_scope_;
223 
224  private:
225   std::hash<const AnfNode *> hash_;
226   ScopePtr scope_;
227   KernelInfoDevicePtr kernel_info_;
228   UserData user_data_;
229   int64_t stage_;
230   bool need_grad_;
231   bool interpret_;
232   AnfNodePtr interpreted_node_;
233 };
234 
235 // CNode represents the complex node with a set of arguments.
236 // Fields:
237 // inputs_: represents all of the inputs for this CNode.
238 // Using input(i) to get the index i input.
239 // Using inputs() to get all the inputs as a vector.
240 // Using add_input(input) to append a new input for a CNode.
241 // Using set_input(i, input) to change some input of these inputs.
242 // Using set_inputs(inputs) to refresh all of the inputs of a CNode.
243 // func_graph_as_var_: used in opt pattern matching to match a real FuncGraph.
244 // stop_gradient_: a flag used to stop gradient.
245 // Using stop_gradient() to get this flag, mainly used in ad.
246 // Using set_stop_gradient() to set this flag.
247 class MS_CORE_API CNode : public AnfNode, public EffectInfoHolder {
248  public:
249   CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph);
CNode(const std::vector<AnfNodePtr> & inputs,const VarPtr & func_graph_as_var)250   CNode(const std::vector<AnfNodePtr> &inputs, const VarPtr &func_graph_as_var)
251       : AnfNode(nullptr),
252         inputs_(inputs),
253         func_graph_as_var_(func_graph_as_var),
254         stop_gradient_(false),
255         input_tensor_num_(-1) {
256     primal_attrs_ = PrimalAttrManager::GetInstance().GetCurrentPrimalAttr();
257     primal_debug_infos_ = PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo();
258   }
259 
260   ~CNode() override = default;
261   MS_DECLARE_PARENT(CNode, AnfNode);
262 
263   void accept(AnfIrVisitor *v) override;
264   // check whether this cnode has some primitive value as the first input.
265   bool IsApply(const PrimitivePtr &) const;
266 
size()267   const size_t size() const { return inputs_.size(); }
268   const AnfNodePtr &input(size_t i) const;
inputs()269   const std::vector<AnfNodePtr> &inputs() const { return inputs_; }
270   void add_input(const AnfNodePtr &input);
271   void set_input(size_t i, const AnfNodePtr &input);
272   void set_inputs(const std::vector<AnfNodePtr> &inputs);
273 
add_input_value(const ValuePtr & input_value,const std::string & id)274   void add_input_value(const ValuePtr &input_value, const std::string &id) {
275     inputs_value_.push_back(std::make_pair(input_value, id));
276   }
clear_inputs_value()277   void clear_inputs_value() { inputs_value_.clear(); }
set_inputs_value(const std::vector<std::pair<ValuePtr,std::string>> & values)278   void set_inputs_value(const std::vector<std::pair<ValuePtr, std::string>> &values) { inputs_value_ = values; }
inputs_value()279   const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
280 
set_forward(const ValueNodePtr & forward,const std::string & id)281   void set_forward(const ValueNodePtr &forward, const std::string &id) { output_value_ = std::make_pair(forward, id); }
forward()282   const std::pair<ValueNodePtr, std::string> &forward() const { return output_value_; }
283 
stop_gradient()284   bool stop_gradient() const { return stop_gradient_; }
set_stop_gradient(bool stop_gradient)285   void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
286 
287   std::string fullname_with_scope() override;
set_fullname_with_scope(const std::string full_name)288   void set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; }
289   std::string DebugString(int recursive_level = 1) const override;
DebugString(bool recursive)290   std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
291 
set_in_forward_flag(bool flag)292   void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
in_forward_flag()293   bool in_forward_flag() const { return in_forward_flag_; }
294 
set_load_flag(bool is_load)295   void set_load_flag(bool is_load) { is_load_ = is_load; }
get_load_flag()296   bool get_load_flag() { return is_load_; }
297 
func_graph_as_var()298   VarPtr func_graph_as_var() const { return func_graph_as_var_; }
299 
attrs()300   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
set_attrs(const std::unordered_map<std::string,ValuePtr> & attrs)301   void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
302     attrs_.insert(attrs.cbegin(), attrs.cend());
303   }
304 
AddAttr(const std::string & name,const ValuePtr & attr)305   void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
EraseAttr(const std::string & name)306   void EraseAttr(const std::string &name) { (void)attrs_.erase(name); }
GetAttr(const std::string & name)307   ValuePtr GetAttr(const std::string &name) const {
308     auto iter = attrs_.find(name);
309     return iter == attrs_.cend() ? nullptr : iter->second;
310   }
HasAttr(const std::string & name)311   bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); }
input_tensor_num()312   ssize_t input_tensor_num() const { return input_tensor_num_; }
313 
primal_attrs()314   const std::unordered_map<std::string, ValuePtr> &primal_attrs() const { return primal_attrs_; }
set_primal_attrs(const std::unordered_map<std::string,ValuePtr> & attrs)315   void set_primal_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
316     primal_attrs_.insert(attrs.cbegin(), attrs.cend());
317   }
AddPrimalAttr(const std::string & name,const ValuePtr & attr)318   void AddPrimalAttr(const std::string &name, const ValuePtr &attr) { primal_attrs_[name] = attr; }
ErasePrimalAttr(const std::string & name)319   void ErasePrimalAttr(const std::string &name) { (void)primal_attrs_.erase(name); }
GetPrimalAttr(const std::string & name)320   ValuePtr GetPrimalAttr(const std::string &name) const {
321     auto iter = primal_attrs_.find(name);
322     return iter == primal_attrs_.cend() ? nullptr : iter->second;
323   }
HasPrimalAttr(const std::string & name)324   bool HasPrimalAttr(const std::string &name) const { return primal_attrs_.find(name) != attrs_.cend(); }
325 
primal_debug_infos()326   std::vector<NodeDebugInfoPtr> primal_debug_infos() { return primal_debug_infos_; }
327 
set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> & debug_infos)328   void set_primal_debug_infos(const std::vector<NodeDebugInfoPtr> &debug_infos) {
329     primal_debug_infos_.insert(primal_debug_infos_.end(), debug_infos.begin(), debug_infos.end());
330   }
331 
AddPrimalDebugInfo(const NodeDebugInfoPtr debug_info)332   void AddPrimalDebugInfo(const NodeDebugInfoPtr debug_info) {
333     if (std::find(primal_debug_infos_.begin(), primal_debug_infos_.end(), debug_info) != primal_debug_infos_.end()) {
334       MS_LOG(EXCEPTION) << "Debug_info already in primal_debug_infos_ vector";
335     }
336     primal_debug_infos_.push_back(debug_info);
337   }
338 
CloneCNodeInfo(const CNodePtr & node)339   void CloneCNodeInfo(const CNodePtr &node) {
340     MS_EXCEPTION_IF_NULL(node);
341     set_abstract(node->abstract());
342     set_forward(node->forward().first, node->forward().second);
343     set_inputs_value(node->inputs_value());
344     set_attrs(node->attrs());
345     set_primal_attrs(node->primal_attrs());
346     set_load_flag(node->get_load_flag());
347     CloneUserData(node);
348     set_kernel_info(node->kernel_info_ptr());
349     set_primal_debug_infos(node->primal_debug_infos());
350   }
351 
set_input_tensor_num(ssize_t input_tensor_num)352   void set_input_tensor_num(ssize_t input_tensor_num) { input_tensor_num_ = input_tensor_num; }
353 
354   // Is effect have been handled.
IsEffectHandled()355   bool IsEffectHandled() const { return effect_handled_; }
356 
357   // Set effect handled or not.
SetEffectHandled(bool handled)358   void SetEffectHandled(bool handled) { effect_handled_ = handled; }
359 
360  private:
361   std::vector<AnfNodePtr> inputs_;
362   VarPtr func_graph_as_var_;
363   bool stop_gradient_;
364   bool in_forward_flag_ = false;
365   bool effect_handled_ = false;
366   bool is_load_ = false;
367   // inputs_value_ store cnode input value and id in pynative mode
368   // output_value_ store cnode value and id in pynative mode
369   std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
370   std::pair<ValueNodePtr, std::string> output_value_;
371   std::unordered_map<std::string, ValuePtr> attrs_;
372   std::unordered_map<std::string, ValuePtr> primal_attrs_;
373   std::vector<NodeDebugInfoPtr> primal_debug_infos_;
374   ssize_t input_tensor_num_ = -1;
375 };
376 
377 // ANode represents the atomic node. It's derived Parameter and ValueNode.
378 class MS_CORE_API ANode : public AnfNode {
379  public:
ANode()380   ANode() : AnfNode(nullptr) {}
ANode(const FuncGraphPtr & func_graph)381   explicit ANode(const FuncGraphPtr &func_graph) : AnfNode(func_graph) {}
382   virtual ~ANode() = default;
383 
384   MS_DECLARE_PARENT(ANode, AnfNode);
385 };
386 
387 // Parameter represents the parameter inputs of a function. They have no value.
388 // Attributes:
389 // default_param_value_: used to hold the inputting tensor of the model.
390 class MS_CORE_API Parameter : public ANode {
391  public:
Parameter(const FuncGraphPtr & func_graph)392   explicit Parameter(const FuncGraphPtr &func_graph)
393       : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {}
394   ~Parameter() override = default;
395   MS_DECLARE_PARENT(Parameter, ANode);
396 
397   void accept(AnfIrVisitor *v) override;
398   std::string DebugString(int recursive_level = 1) const override;
name()399   std::string name() const { return name_; }
set_name(const std::string & name)400   void set_name(const std::string &name) { name_ = name; }
fullname_with_scope()401   std::string fullname_with_scope() override { return name(); }
402 
has_default()403   bool has_default() const { return has_default_; }
set_default_param(ValuePtr param)404   void set_default_param(ValuePtr param) {
405     default_param_ = param;
406     has_default_ = true;
407   }
default_param()408   ValuePtr default_param() const { return default_param_; }
409   ParamInfoPtr param_info() const;
410 
IncreaseUsedGraphCount()411   void IncreaseUsedGraphCount() { used_graph_count_++; }
DecreaseUsedGraphCount()412   void DecreaseUsedGraphCount() { used_graph_count_--; }
used_graph_count()413   int used_graph_count() const { return used_graph_count_; }
414 
415   bool operator==(const AnfNode &other) const override {
416     if (!other.isa<Parameter>()) {
417       return false;
418     }
419     auto p = static_cast<const Parameter &>(other);
420     if (name_.length() > 0 && p.name_.length() > 0) {
421       return p.name_ == name_;
422     }
423     return shared_from_this() == other.shared_from_this();
424   }
425 
SetNotUsedByRealKernelInGraph(uint32_t graph_id)426   void SetNotUsedByRealKernelInGraph(uint32_t graph_id) { (void)not_used_in_graphs_.insert(graph_id); }
427 
IsUsedByRealKernelInGraph(uint32_t graph_id)428   bool IsUsedByRealKernelInGraph(uint32_t graph_id) const {
429     if (not_used_in_graphs_.find(graph_id) != not_used_in_graphs_.end()) {
430       return false;
431     }
432     return true;
433   }
434 
set_has_dynamic_shape(bool flag)435   void set_has_dynamic_shape(bool flag) { has_dynamic_shape_ = flag; }
has_dynamic_shape()436   bool has_dynamic_shape() const { return has_dynamic_shape_; }
437 
set_fracz_group(int64_t fracz_group)438   void set_fracz_group(int64_t fracz_group) { fracz_group_ = fracz_group; }
fracz_group()439   int64_t fracz_group() { return fracz_group_; }
440 
441  private:
442   std::string name_;
443   bool has_default_;
444   std::set<uint32_t> not_used_in_graphs_;
445   bool has_dynamic_shape_ = false;
446   ValuePtr default_param_;
447   // The count of graphs using the parameter.
448   int used_graph_count_;
449   // groups attr in FracZ format
450   int64_t fracz_group_ = 1;
451 };
452 using ParameterPtr = std::shared_ptr<Parameter>;
453 
454 // Value is used to represent the atomic expression mentioned in BNF.
455 // It mainly be stored in ValueNode. Value and ValueNode is related definition.
456 class MS_CORE_API Value : public Base {
457  public:
458   Value() = default;
Value(const TypePtr t)459   explicit Value(const TypePtr t) : type_(t) {}
Value(const Value & other)460   Value(const Value &other) : Base(other) { this->type_ = other.type_; }
461   ~Value() override = default;
MS_DECLARE_PARENT(Value,Base)462   MS_DECLARE_PARENT(Value, Base)
463 
464   TypePtr type() const { return type_; }
ToAbstract()465   virtual abstract::AbstractBasePtr ToAbstract() {
466     MS_LOG(EXCEPTION) << "ToAbstract error";
467     abstract::AbstractBasePtr result;
468     return result;
469   }
470 
471   virtual bool operator==(const Value &rhs) const = 0;
472   virtual Value &operator=(const Value &other) {
473     if (&other == this) {
474       return *this;
475     }
476     this->type_ = other.type_;
477     return *this;
478   }
479 
480  protected:
481   TypePtr type_{nullptr};
482 };
483 
484 // ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
485 // does not belong to any particular function graph.
486 class MS_CORE_API ValueNode : public ANode {
487  public:
ValueNode(const ValuePtr & value)488   explicit ValueNode(const ValuePtr &value) : value_(value) {}
489   ~ValueNode() override = default;
490   MS_DECLARE_PARENT(ValueNode, ANode);
491 
set_func_graph(const FuncGraphPtr & func_graph)492   void set_func_graph(const FuncGraphPtr &func_graph) override {
493     MS_EXCEPTION(ValueError) << "ValueNode should not set its func_graph.";
494   }
495 
496   void accept(AnfIrVisitor *v) override;
set_value(const ValuePtr & value)497   void set_value(const ValuePtr &value) { value_ = value; }
value()498   const ValuePtr &value() const { return value_; }
499   std::string fullname_with_scope() override;
500 
set_has_new_value(bool flag)501   void set_has_new_value(bool flag) { has_new_value_ = flag; }
has_new_value()502   bool has_new_value() const { return has_new_value_; }
503 
used_graph_count()504   size_t used_graph_count() const { return used_graph_count_; }
set_used_graph_count(size_t used_graph_count)505   void set_used_graph_count(size_t used_graph_count) { used_graph_count_ = used_graph_count; }
506 
507   std::string ToString() const override;
508   std::string DebugString(int recursive_level = 1) const override;
DebugString(bool recursive)509   std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
510 
511   bool operator==(const AnfNode &other) const override {
512     if (!other.isa<ValueNode>()) {
513       return false;
514     }
515     auto v = static_cast<const ValueNode &>(other);
516     return *v.value() == *value();
517   }
518   friend std::ostream &operator<<(std::ostream &os, const ValueNodePtr &node) {
519     MS_EXCEPTION_IF_NULL(node);
520     os << node->ToString();
521     return os;
522   }
523 
524  private:
525   ValuePtr value_;
526   bool has_new_value_ = false;
527   size_t used_graph_count_{0};
528 };
529 
// Primary template: maps a C++ type to the pointer type of the Value subclass that
// wraps it. It has no `type` member, so MakeValue/GetValue cannot be used with a
// type that has no IMM_TRAITS specialization.
template <class T>
struct ImmTraits {};

// Declare that ImmTraits<prototype>::type is `typeimm`.
#define IMM_TRAITS(typeimm, prototype) \
  template <>                          \
  struct ImmTraits<prototype> {        \
    using type = typeimm;              \
  };
538 
MakeValue(const ValuePtr & value)539 inline ValuePtr MakeValue(const ValuePtr &value) { return value; }
540 
541 template <typename S, typename U = typename ImmTraits<S>::type::element_type>
MakeValue(S v)542 inline ValuePtr MakeValue(S v) {
543   return std::make_shared<U>(v);
544 }
545 
546 template <typename S, typename U = typename ImmTraits<S>::type>
GetValue(const ValuePtr & value)547 static S GetValue(const ValuePtr &value) {
548   MS_EXCEPTION_IF_NULL(value);
549   U imm = value->cast<U>();
550   if (imm == nullptr) {
551     MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name();
552   }
553   return imm->value();
554 }
555 
556 template <typename S,
557           typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
558                                   S>::type * = nullptr>
GetValue(const ValuePtr & value)559 static S GetValue(const ValuePtr &value) {
560   MS_EXCEPTION_IF_NULL(value);
561   S v = value->cast<S>();
562   if (v == nullptr) {
563     MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name();
564   }
565   return v;
566 }
567 
568 std::string GetCNodeFuncName(CNodePtr cnode);
569 
570 // used to get FuncGraphPtr from a cnode first input
571 FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node);
572 
573 // used to check whether an AnfNode is a cnode with a kind of Primitive as first input
574 bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
575 
576 // used to get PrimitivePtr from a cnode first input
577 PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
578 
579 // used to check whether an AnfNode is a valuenode having some Primitive value
580 MS_CORE_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value);
581 
582 // Check whether two primitives are same.
583 bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2);
584 
585 // Get number of AbstractMonad
586 size_t GetAbstractMonadNum(const AbstractBasePtrList &args);
587 
588 // Check whether the given node has monad abstract.
589 bool HasAbstractMonad(const AnfNodePtr &node);
590 
591 // Check whether the given node has U monad abstract.
592 bool HasAbstractUMonad(const AnfNodePtr &node);
593 
594 // Check whether the given node has IO monad abstract.
595 bool HasAbstractIOMonad(const AnfNodePtr &node);
596 
597 // Gets primitive attribute value as a bool flag.
598 bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr);
599 
600 // Gets effect info from a primitive by its attributes.
601 EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim);
602 
603 struct MonadState {
604   AnfNodePtr u{nullptr};
605   AnfNodePtr io{nullptr};
606 };
607 
608 // Get Memory/IO monad state from node.
609 MonadState GetMonadState(const AnfNodePtr &node, const AnfNodePtr &skip_input = nullptr);
610 
// Check whether two monad states are equivalent.
612 bool IsStateEquivalent(const MonadState &state1, const MonadState &state2);
613 
// Check whether the monad states of two connected nodes are strictly equivalent.
615 bool IsStateStrictEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner);
616 
// Check whether the monad states of two connected nodes are equivalent; less strict than IsStateStrictEquivalent, but faster.
618 bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner);
619 
620 // used to check whether a ValueNode has some kind of value
621 template <typename T>
IsValueNode(const AnfNodePtr & node)622 static bool IsValueNode(const AnfNodePtr &node) {
623   MS_EXCEPTION_IF_NULL(node);
624   auto anode = node->cast<ValueNodePtr>();
625   if (anode != nullptr) {
626     auto value = anode->value();
627     if (value == nullptr) {
628       MS_LOG(EXCEPTION) << "Const value is nullptr.";
629     }
630     return value->isa<T>();
631   }
632   return false;
633 }
634 
GetValueNode(const AnfNodePtr & node)635 inline ValuePtr GetValueNode(const AnfNodePtr &node) {
636   MS_EXCEPTION_IF_NULL(node);
637   if (!node->isa<ValueNode>()) {
638     return nullptr;
639   }
640   return node->cast<ValueNodePtr>()->value();
641 }
642 
643 template <typename S,
644           typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
645                                   S>::type * = nullptr>
GetValueNode(const AnfNodePtr & node)646 inline S GetValueNode(const AnfNodePtr &node) {
647   auto value = GetValueNode(node);
648   if (value == nullptr) {
649     return nullptr;
650   }
651   auto s = value->cast<S>();
652   return s;
653 }
654 
655 size_t NewSeenGeneration();
656 
657 namespace id_generator {
658 std::string get_id(const AnfNodePtr &node);
659 void reset_id();
660 }  // namespace id_generator
661 using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
662 using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
663 std::string GetCNodeTarget(const AnfNodePtr &node);
664 std::string GetOriginNodeTarget(const AnfNodePtr &node);
665 bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
666 struct GraphSegment {
GraphSegmentGraphSegment667   GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
AddPreSegmentGraphSegment668   void AddPreSegment(const std::shared_ptr<GraphSegment> &segment) { (void)pre_segments_.insert(segment); }
669   std::vector<AnfNodePtr> nodes_;
670   std::set<std::shared_ptr<GraphSegment>> pre_segments_;
671   bool is_cut_{false};
672   uint32_t graph_id_{0};
673 };
674 using GraphSegmentPtr = std::shared_ptr<GraphSegment>;
675 }  // namespace mindspore
676 
677 #endif  // MINDSPORE_CORE_IR_ANF_H_
678