• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_IR_ANF_H_
20 #define MINDSPORE_CORE_IR_ANF_H_
21 
22 #include <functional>
23 #include <string>
24 #include <vector>
25 #include <memory>
26 #include <utility>
27 #include <set>
28 #include <bitset>
29 #include "utils/hash_map.h"
30 #include "utils/hash_set.h"
31 #include "base/base.h"
32 #include "base/effect_info.h"
33 #include "ir/kernel_info_dev.h"
34 #include "ir/scope.h"
35 #include "ir/primal_attr.h"
36 #include "ir/primal_debug_info.h"
37 #include "utils/info.h"
38 #include "utils/hashing.h"
39 #include "utils/ms_utils.h"
40 #include "utils/os.h"
41 
// The MindSpore ANF IR is defined here,
// following this BNF:
// <ANode> ::= Scalar | Named | Tensor | Var |
//             Prim   | MetaFuncGraph | FuncGraph | Type |
//             Shape  | Param
// <CNode> ::= (<ANode> ...)
// <AnfNode> ::= <CNode> | <ANode>
// ANode: Atomic Node
// CNode: Complex Node
51 namespace mindspore {
52 namespace abstract {
53 class BaseShape;
54 class AbstractBase;
55 }  // namespace abstract
56 using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
57 using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
58 using AbstractBasePtrList = std::vector<AbstractBasePtr>;
59 using NodeDebugInfoSet = std::set<NodeDebugInfoPtr, DebugInfoCompare>;
60 using SeenNum = uint32_t;
61 
62 class Value;
63 using ValuePtr = std::shared_ptr<Value>;
64 using ValuePtrList = std::vector<ValuePtr>;
65 
66 class ValueNode;
67 using ValueNodePtr = std::shared_ptr<ValueNode>;
68 
69 class CNode;
70 using CNodePtr = std::shared_ptr<CNode>;
71 using CNodePtrList = std::vector<CNodePtr>;
72 using CNodeWeakPtr = std::weak_ptr<CNode>;
73 
74 class FuncGraph;
75 using FuncGraphSet = OrderedSet<FuncGraphPtr>;
76 using FuncGraphVector = std::vector<FuncGraphPtr>;
77 
78 class Primitive;
79 using PrimitivePtr = std::shared_ptr<Primitive>;
80 struct PrimitiveHasher;
81 struct PrimitiveEqual;
82 using PrimitiveSet = mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
83 
84 class BaseRef;
85 
86 class Var;
87 using VarPtr = std::shared_ptr<Var>;
88 
89 class AnfIrVisitor;
90 
91 class ParamInfo;
92 using ParamInfoPtr = std::shared_ptr<ParamInfo>;
93 
94 // AnfNode is the basic class of the IR definition derived from Base.
95 // Only two types of nodes are derived: CNode and ANode.
96 // Methods:
97 // func_graph: return FuncGraph that this AnfNode belongs to.
98 // scope: return the scope namespace of this AnfNode. Set it using set_scope.
99 // abstract: return the cached inferred abstract value. It contains type, shape
100 // value. Set New cache using set_abstract.
101 // Type/Shape: return the related info of this AnfNode. When this AnfNode is an
102 // input of other CNodes, you can get the related info by this method.
103 // debug_info: return the information retrieved from parser. Set it using set_debug_info.
104 // fullname_with_scope: return the detailed debug info.
105 
106 /// \brief AnfNode is the basic class of the IR definition derived from Base.
107 class MS_CORE_API AnfNode : public Base {
108  public:
109   /// \brief Constructor.
110   ///
111   /// \param[in] func_graph The FuncGraph to which this AnfNode belongs.
112   /// \param[in] debug_info The debug info to be used for this AnfNode.
113   AnfNode(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info);
114 
115   /// \brief Constructor.
116   ///
117   /// \param[in] func_graph The FuncGraph to which this AnfNode belongs.
118   explicit AnfNode(const FuncGraphPtr &func_graph);
119 
120   /// \brief Destructor.
121   ~AnfNode() override = default;
122   MS_DECLARE_PARENT(AnfNode, Base);
123 
124   /// \brief Use the method of the AnfIrVisitor class to process the node.
125   virtual void accept(AnfIrVisitor *);
126 
127   /// \brief Obtain the FuncGraph to which this AnfNode belongs.
128   ///
129   /// \return The FuncGraph to which this AnfNode belongs.
130   FuncGraphPtr func_graph() const;
131 
132   /// \brief Set the FuncGraph to which this AnfNode belongs.
133   ///
134   /// \param[in] func_graph The input FuncGraph.
135   virtual void set_func_graph(const FuncGraphPtr &func_graph);
136 
137   /// \brief Obtain the scope namespace of this AnfNode.
138   ///
139   /// \return The scope namespace.
140   ScopePtr scope();
141 
142   /// \brief Set the scope namespace of this AnfNode.
143   ///
144   /// \param[in] scope New scope namespace.
145   void set_scope(const ScopePtr &scope);
146 
147   /// \brief Obtain device kernel program information.
148   ///
149   /// \return Device kernel program information.
150   const KernelInfoDevice *kernel_info() const;
151 
152   /// \brief Obtain device kernel program information.
153   ///
154   /// \return Device kernel program information.
155   KernelInfoDevice *kernel_info();
156 
157   /// \brief Obtain the pointer of KernelInfoDevice.
158   ///
159   /// \return The pointer of KernelInfoDevice.
160   KernelInfoDevicePtr kernel_info_ptr() const;
161 
162   /// \brief Set device kernel program information.
163   ///
164   /// \param[in] kernel_info New device kernel program information.
165   void set_kernel_info(const KernelInfoDevicePtr &kernel_info);
166 
167   /// \brief Obtain the inferred abstract value of this AnfNode.
168   ///
169   /// \return The inferred abstract value.
170   const AbstractBasePtr &abstract() const;
171 
172   /// \brief Set the abstract value of this AnfNode.
173   ///
174   /// \param[in] abs New abstract value.
175   void set_abstract(const AbstractBasePtr &abs);
176 
177   /// \brief Obtain the debugging information of this AnfNode.
178   ///
179   /// \return The debugging information of this AnfNode.
180   NodeDebugInfoPtr debug_info();
181 
182   /// \brief Set the debugging information of this AnfNode.
183   ///
184   /// \return New debugging information.
185   virtual void set_debug_info(const NodeDebugInfoPtr &debug_info);
186 
187   /// \brief Obtain the type of the element in this AnfNode.
188   ///
189   /// \return The type of the element.
190   TypePtr Type() const;
191 
192   /// \brief Obtain the shape of the element in this AnfNode.
193   ///
194   /// \return The shape of the element.
195   BaseShapePtr Shape() const;
196 
197   std::size_t hash() const final;
198 
199   /// \brief Obtain detailed information about scope namespace.
200   ///
201   /// \return Detailed information about scope namespace.
202   virtual std::string fullname_with_scope();
203 
204   /// \brief Obtain the unique name of this AnfNode.
205   ///
206   /// \return The unique name of this AnfNode.
207   std::string UniqueName();
208 
209   /// \brief Obtain the display information of this AnfNode.
210   ///
211   /// \param[in] recursive_level Recursion level when displayed.
212   /// \return Information to be displayed.
213   virtual std::string DebugString(int recursive_level = 1) const;
214 
215   /// \brief Obtain the display information of this AnfNode.
216   ///
217   /// \param[in] recursive Whether to display AnfNode recursively.
218   /// \return Information to be displayed.
219   virtual std::string DebugString(bool recursive) const;
220 
221   std::string ToString() const override;
222 
223   void dump() const override;
224 
225   /// \brief Obtain the unique id of the debug information of this AnfNode.
226   ///
227   /// \return Unique id.
228   std::string UniqueId();
229 
230   /// \brief Obtain the unique id through copied traced information.
231   ///
232   /// \return Unique id.
233   std::string UniqueIdThroughCopy();
234 
235   /// \brief Determine whether two AnfNodes are the same.
236   ///
237   /// \param[in] other Another ANfNode.
238   /// \return True if the same, otherwise False.
239   virtual bool operator==(const AnfNode &other) const;
240 
241   /// \brief Obtain the display information of this AnfNode.
242   ///
243   /// \param[in] os Output stream.
244   /// \param[in] node AnfNode to be displayed.
245   /// \return Output stream.
246   friend std::ostream &operator<<(std::ostream &os, const AnfNode &node);
247 
248   /// \brief Check if there is an interpret node.
249   ///
250   /// \return True if there is an interpret node, otherwise false.
251   bool interpret() const;
252 
253   /// \brief Whether to use interpretation
254   ///
255   /// \param[in] interpret Boolean.
256   void set_interpret(const bool &interpret);
257 
258   /// \brief Check if there is an interpret node related to the unsupported internal type.
259   ///
260   /// \return True if there is an interpret node related to the unsupported internal type, otherwise false.
261   bool interpret_internal_type();
262 
263   /// \brief Whether there is an interpret node with unsupported internal type.
264   ///
265   /// \param[in] interpret_internal_type Boolean.
266   void set_interpret_internal_type(const bool &interpret_internal_type);
267 
268   SeenNum seen_{0};
269   SeenNum extra_seen_{0};
270 
271  protected:
272   // Hold a weak ref to Graph as Graph also hold ref to AnfNode.
273   // Otherwise, func_graph_ and AnfNode will make a reference cycle.
274   FuncGraphWeakPtr func_graph_;
275   AbstractBasePtr abstract_;
276   NodeDebugInfoPtr debug_info_;
277   std::string fullname_with_scope_;
278 
279  private:
280   static constexpr size_t kInterpret = 0;
281   static constexpr size_t kInterpretInternalType = 1;
282   static constexpr size_t kNumInterpretFlags = 2;
283   static constexpr auto kKernelInfoKey = "kernel_info";
284 
285   ScopePtr scope_;
286   std::bitset<kNumInterpretFlags> interpret_flags_;
287 };
288 
289 // CNode represents the complex node with a set of arguments.
290 // Fields:
291 // inputs_: represents all of the inputs for this CNode.
292 // weak_inputs_: represents all of the weak inputs for this CNode.
293 // Using input(i) to get the index i input.
294 // Using inputs() to get all the inputs as a vector.
295 // Using add_input(input) to append a new input for a CNode.
296 // Using set_input(i, input) to change some input of these inputs.
297 // Using set_inputs(inputs) to refresh all of the inputs of a CNode.
298 // func_graph_as_var: used in opt pattern matching to match a real FuncGraph.
299 // stop_gradient: a flag used to stop gradient.
300 // Using stop_gradient() to get this flag, mainly used in ad.
301 // Using set_stop_gradient() to set this flag.
302 class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
303  public:
304   /// \brief Constructor.
305   ///
306   /// \param[in] inputs Input nodes of this Cnode.
307   /// \param[in] func_graph The FuncGraph to which this CNode belongs.
308   CNode(const AnfNodePtrList &inputs, const FuncGraphPtr &func_graph);
309 
310   /// \brief Constructor.
311   ///
312   /// \param[in] inputs Input nodes of this Cnode.
313   /// \param[in] func_graph_as_var The FuncGraph of type VarPtr to which this CNode belongs,
314   CNode(const AnfNodePtrList &inputs, const VarPtr &func_graph_as_var);
315 
316   /// \brief Constructor.
317   ///
318   /// \param[in] weak_inputs Input nodes of this Cnode.
319   /// \param[in] func_graph The FuncGraph to which this CNode belongs.
320   CNode(AnfNodeWeakPtrList &&weak_inputs, const FuncGraphPtr &func_graph);
321 
322   /// \brief Constructor.
323   ///
324   /// \param[in] weak_inputs Input nodes of this Cnode.
325   /// \param[in] func_graph The FuncGraph to which this CNode belongs.
326   CNode(const AnfNodeWeakPtrList &weak_inputs, const FuncGraphPtr &func_graph);
327 
328   /// \brief Constructor.
329   ///
330   /// \param[in] weak_inputs Input nodes of this Cnode.
331   /// \param[in] func_graph The FuncGraph to which this CNode belongs.
332   /// \param[in] debug_info The debug info to be used for this CNode.
333   CNode(AnfNodeWeakPtrList &&weak_inputs, const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info);
334 
335   /// \brief Destructor.
336   ~CNode() override = default;
337   MS_DECLARE_PARENT(CNode, AnfNode);
338 
339   void accept(AnfIrVisitor *v) override;
340 
341   /// \brief Check whether this cnode has the same primitive value as the first input.
342   ///
343   /// \return True if they have the same primitive value, otherwise false.
344   bool IsApply(const PrimitivePtr &value) const;
345 
346   /// \brief Obtain the size of input nodes of this CNode.
347   ///
348   /// \return Size of input nodes.
349   const size_t size() const;
350 
351   /// \brief Check if this CNode has no input node.
352   ///
353   /// \return if input nodes is empty.
354   const bool empty() const;
355 
356   /// \brief Get the input node of the given index.
357   /// The weak_input() is recommended.
358   ///
359   /// \param[in] i The given index.
360   /// \return The input node of the given index.
361   const AnfNodePtr input(size_t i) const;
362 
363   /// \brief Get the input node of the given index.
364   ///
365   /// \param[in] i The given index.
366   /// \return The input node of the given index.
367   const AnfNodeWeakPtr &weak_input(size_t i) const;
368 
369   /// \brief Get the input nodes.
370   /// The weak_inputs() is recommended.
371   ///
372   /// \return The input nodes of this CNode.
373   const AnfNodePtrList &inputs();
374 
375   /// \brief Get the input nodes.
376   ///
377   /// \return The input nodes of this CNode.
378   const AnfNodeWeakPtrList &weak_inputs() const;
379 
380   /// \brief Add the input node to this CNode.
381   ///
382   /// \param[in] input Node.
383   void add_input(const AnfNodePtr &input);
384 
385   /// \brief Set the input node of the given index.
386   ///
387   /// \param[in] i The given index.
388   /// \param[in] input Node.
389   void set_input(size_t i, const AnfNodePtr &new_input);
390 
391   /// \brief Set the input nodes for this CNode.
392   ///
393   /// \param[in] inputs Input nodes.
394   void set_inputs(const AnfNodePtrList &inputs);
395 
396   /// \brief Set the input nodes for this CNode.
397   ///
398   /// \param[in] weak_inputs Input nodes.
399   void set_weak_inputs(const AnfNodeWeakPtrList &weak_inputs);
400 
401   // output_value store cnode value and id in pynative mode.
402   using OutputValue = std::pair<ValueNodePtr, std::string>;
403 
404   /// \brief Record the cnode value and id to output_value_.
405   ///
406   /// \param[in] forward The cnode value.
407   /// \param[in] id The id.
408   void set_forward(const ValueNodePtr &forward, const std::string &id);
409 
410   /// \brief Get the record of output value of this CNode.
411   ///
412   /// \return The output value of this CNode.
413   const OutputValue &forward() const;
414 
415   /// \brief Check if stop_gradient is set.
416   ///
417   /// \return True if stop_gradient is set, otherwise false.
418   bool stop_gradient() const;
419 
420   /// \brief Set stop_gradient.
421   ///
422   /// \param[in] stop_gradient Boolean.
423   void set_stop_gradient(bool stop_gradient);
424 
425   std::string fullname_with_scope() override;
426 
427   /// \brief Set fullname_with_scope for this CNode.
428   ///
429   /// \param[in] full_name The fullname_with_scope.
430   void set_fullname_with_scope(const std::string full_name);
431 
432   std::string DebugString(int recursive_level = 1) const override;
433   std::string DebugString(bool recursive) const override;
434 
435   /// \brief Set in_forward_flag for this CNode.
436   ///
437   /// \param[in] flag Boolean.
438   void set_in_forward_flag(bool flag);
439   /// \brief Check if in_forward_flag is set.
440   ///
441   /// \return True if in_forward_flag is set, otherwise false.
442   bool in_forward_flag() const;
443 
444   /// \brief Check if the primitive of this CNode is load.
445   ///
446   /// \param[in] is_load Boolean.
447   void set_load_flag(bool is_load);
448   /// \brief Check if is_load_ is set.
449   ///
450   /// \return True if is_load_ is set, otherwise false.
451   bool get_load_flag() const;
452 
453   /// \brief Get func_graph_as_var of this CNode.
454   ///
455   /// \return func_graph_as_var.
456   VarPtr func_graph_as_var() const;
457 
458   /// \brief Get all attributes of this CNode.
459   ///
460   /// \return Attributes of this CNode.
461   const mindspore::HashMap<std::string, ValuePtr> &attrs() const;
462   void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs);
463 
464   /// \brief Add a new attribute to this CNode.
465   ///
466   /// \param[in] name The name of the new attribute.
467   /// \param[in] attr The value of the new attribute.
468   void AddAttr(const std::string &name, const ValuePtr &attr);
469 
470   /// \brief Erase the attribute with the given name.
471   ///
472   /// \param[in] name The name of attribute.
473   void EraseAttr(const std::string &name);
474 
475   /// \brief Get the attribute with the given name.
476   ///
477   /// \param[in] name The name of attribute.
478   /// \return Attribute.
479   ValuePtr GetAttr(const std::string &name) const;
480 
481   /// \brief Check whether this CNode has an attribute with the given name.
482   ///
483   /// \param[in] name The name of attribute.
484   /// \return Boolean.
485   bool HasAttr(const std::string &name) const;
486 
487   /// \brief Get the number of input tensors.
488   ///
489   /// \return The number of input tensors.
490   ssize_t input_tensor_num() const;
491 
492   /// \brief Get the primal attributes of this CNode.
493   ///
494   /// \return The primal attributes.
495   const mindspore::HashMap<std::string, ValuePtr> &primal_attrs() const;
496 
497   /// \brief Set the primal attributes of this CNode.
498   ///
499   /// \param[in] attrs The primal attributes.
500   void set_primal_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs);
501 
502   /// \brief Add the primal attribute to this CNode.
503   ///
504   /// \param[in] name The name of the attribute.
505   /// \param[in] attr The attribute.
506   void AddPrimalAttr(const std::string &name, const ValuePtr &attr);
507 
508   /// \brief Erase the primal attribute with the given name.
509   ///
510   /// \param[in] name The name of the attribute.
511   void ErasePrimalAttr(const std::string &name);
512 
513   /// \brief Get the primal attribute with the given name.
514   ///
515   /// \param[in] name The name of the attribute.
516   /// \return The primal attribute with the given name.
517   ValuePtr GetPrimalAttr(const std::string &name) const;
518 
519   /// \brief Check whether this CNode has an attribute with the given name.
520   ///
521   /// \param[in] name The name of the attribute.
522   /// \return True if it exists, otherwise false.
523   bool HasPrimalAttr(const std::string &name) const;
524 
525   /// \brief Get primal debug information.
526   ///
527   /// \return The primal debug information.
528   NodeDebugInfoSet primal_debug_infos() const;
529 
530   /// \brief Set primal debug information.
531   ///
532   /// \param[in] debug_infos Debug information of this CNode.
533   void set_primal_debug_infos(const NodeDebugInfoSet &debug_infos);
534 
535   /// \brief Add a primal debug information.
536   ///
537   /// \param[in] debug_info A debug information.
538   void AddPrimalDebugInfo(const NodeDebugInfoPtr &debug_info);
539 
540   void CloneCNodeInfo(const CNodePtr &node);
541 
542   /// \brief Set the number of input tensors.
543   ///
544   /// \param[in] The number of input tensors.
545   void set_input_tensor_num(ssize_t input_tensor_num);
546 
547   /// \brief Is effect have been handled.
548   ///
549   /// \return True if effect have been handled, otherwise false.
550   bool IsEffectHandled() const;
551 
552   /// \brief Set effect handled or not.
553   ///
554   /// \param[in] handled Boolean.
555   void SetEffectHandled(bool handled);
556 
557   /// \brief Get the debug infos of fused nodes.
558   ///
559   /// \return A vector of debug infos.
560   NodeDebugInfoSet fused_debug_infos() const;
561 
562   /// \brief Set the debug infos for CNode.
563   ///
564   /// \param fused_debug_infos The debug infos to be set.
565   void set_fused_debug_infos(const NodeDebugInfoSet &fused_debug_infos);
566 
567   /// \brief Add a node's debug info or fused debug info.
568   ///
569   /// \param node An anf node.
570   void AddFusedDebugInfo(const AnfNodePtr &node);
571 
572   /// \brief Add a vector of nodes' debug info or fused debug info.
573   ///
574   /// \param nodes A vector of anf nodes.
575   void AddFusedDebugInfoList(const AnfNodePtrList &nodes);
576 
577   /// \brief Add a node debug info.
578   ///
579   /// \param debug_info A node debug info of an anf node.
580   void AddFusedDebugInfo(const NodeDebugInfoPtr &debug_info);
581 
582   /// \brief Add a list of node debug infos.
583   ///
584   /// \param debug_infos A node debug info of an anf node.
585   void AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos);
586 
587   /// \brief Check whether contains a input or indirect input, which is Depend CNode with isolated side-effect node.
588   ///
589   /// \return True if contains, otherwise false.
590   bool has_side_effect_node() const;
591 
592   /// \brief Set whether contains a input or indirect input, which is Depend CNode with isolated side-effect node.
593   ///
594   /// \param[in] has_side_effect_node Boolean.
595   void set_has_side_effect_node(bool has_side_effect_node);
596 
597   /// \brief Set the debugging information of this AnfNode.
598   ///
599   /// \return New debugging information.
600   void set_debug_info(const NodeDebugInfoPtr &debug_info) override;
601 
602  private:
603   void Init();
604   void CheckCNodeWeakInput();
605 
606   static constexpr size_t kStopGradient = 0;
607   static constexpr size_t kInForwardFlag = 1;
608   static constexpr size_t kEffectHandled = 2;
609   static constexpr size_t kIsLoad = 3;
610   static constexpr size_t kNumFlags = 4;
611   static constexpr auto kFuncGraphVarKey = "fg_var";
612   static constexpr auto kOutputValueKey = "out_value";
613 
614   AnfNodeWeakPtrList weak_inputs_;
615   // Most of the time it's empty.
616   // It's set after call CNode::inputs() for compatibility with earlier version.
617   AnfNodePtrList inputs_;
618   bool inputs_bound_{false};
619 
620   ssize_t input_tensor_num_ = -1;
621   std::bitset<kNumFlags> flags_;
622 
623   mindspore::HashMap<std::string, ValuePtr> attrs_;
624   mindspore::HashMap<std::string, ValuePtr> primal_attrs_;
625   NodeDebugInfoSet primal_debug_infos_;
626   NodeDebugInfoSet fused_debug_infos_;
627 
628   // If the inputs or their inputs contain Depend CNode with isolated side-effect node.
629   bool has_side_effect_node_{false};
630 };
631 
632 // ANode represents the atomic node. It's derived Parameter and ValueNode.
633 class MS_CORE_API ANode : public AnfNode {
634  public:
ANode()635   ANode() : AnfNode(nullptr) {}
636 
637   /// \brief Constructor.
638   ///
639   /// \param[in] func_graph The FuncGraph to which this ANode belongs.
ANode(const FuncGraphPtr & func_graph)640   explicit ANode(const FuncGraphPtr &func_graph) : AnfNode(func_graph) {}
641 
642   /// \brief Constructor.
643   ///
644   /// \param[in] func_graph The FuncGraph to which this ANode belongs.
645   /// \param[in] debug_info The debug info to be used for this ANode.
ANode(const FuncGraphPtr & func_graph,NodeDebugInfoPtr && debug_info)646   ANode(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info) : AnfNode(func_graph, std::move(debug_info)) {}
647 
648   /// \brief Destructor.
649   virtual ~ANode() = default;
650 
651   MS_DECLARE_PARENT(ANode, AnfNode);
652 };
653 
654 // Parameter represents the parameter inputs of a function. They have no value.
655 // Attributes:
656 // default_param_value_: used to hold the inputting tensor of the model.
657 class MS_CORE_API Parameter final : public ANode {
658  public:
659   explicit Parameter(const FuncGraphPtr &func_graph);
660 
661   Parameter(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info);
662 
663   /// \brief Destructor.
664   ~Parameter() override = default;
665   MS_DECLARE_PARENT(Parameter, ANode);
666 
667   void accept(AnfIrVisitor *v) override;
668   std::string DebugString(int recursive_level = 1) const override;
669 
670   /// \brief Get the name of this Parameter.
671   ///
672   /// \return The name.
673   std::string name() const;
674 
675   /// \brief Set the name of this Parameter.
676   ///
677   /// \param[in] The name.
678   void set_name(const std::string &name);
679 
680   std::string fullname_with_scope() override;
681 
682   /// \brief Check if there is a default parameter.
683   ///
684   /// \return True if this Parameter has a default parameter, otherwise false.
685   bool has_default() const;
686 
687   /// \brief Set the default parameter.
688   ///
689   /// \param[in] param The default parameter.
690   void set_default_param(const ValuePtr &param);
691 
692   /// \brief Get the default parameter.
693   ///
694   /// \return The default parameter.
695   const ValuePtr &default_param() const;
696 
697   /// \brief Get the parameter information.
698   ///
699   /// \return The parameter information.
700   ParamInfoPtr param_info() const;
701 
702   /// \brief Increase used_graph_count.
703   void IncreaseUsedGraphCount();
704   /// \brief Decrease used_graph_count.
705   void DecreaseUsedGraphCount();
706   /// \brief Get used_graph_count.
707   ///
708   /// \return used_graph_count.
709   int used_graph_count() const;
710 
711   bool is_top_graph_param() const;
712   void set_is_top_graph_param(bool flag);
713 
714   bool operator==(const AnfNode &other) const override;
715 
716   /// \brief This parameter is not used in graph with id.
717   ///
718   /// \param[in] graph_id The graph id.
719   void SetNotUsedByRealKernelInGraph(uint32_t graph_id);
720 
721   /// \brief Check if this Parameter is used in graph with id.
722   ///
723   /// \param[in] graph_id True if used, otherwise false.
724   bool IsUsedByRealKernelInGraph(uint32_t graph_id) const;
725 
726   /// \brief Set whether this Parameter has a dynamic shape.
727   ///
728   /// \param[in] flag Boolean.
729   void set_has_dynamic_shape(bool flag);
730 
731   /// \brief Check whether this Parameter has a dynamic shape.
732   ///
733   /// \return True if this Parameter has a dynamic shape, otherwise false.
734   bool has_dynamic_shape() const;
735 
736   /// \brief Set whether this Parameter is dynamic len.
737   ///
738   /// \param[in] flag Boolean.
739   void set_dynamic_len(bool flag);
740 
741   /// \brief Check whether this Parameter is dynamic len.
742   ///
743   /// \return True if this Parameter is dynamic len, otherwise false.
744   bool dynamic_len() const;
745 
746   /// \brief Set groups attr in FRACTAL_Z format.
747   ///
748   /// \param[in] fracz_group Groups attr in FRACTAL_Z format.
749   void set_fracz_group(int64_t fracz_group);
750 
751   /// \brief Get groups attr in FRACTAL_Z format.
752   ///
753   /// \return Groups attr in FRACTAL_Z format.
754   int64_t fracz_group() const;
755 
756   /// \brief Set input_size attr in FracNZ_RNN or ND_RNN_Bias format.
757   ///
758   /// \param[in] input_size input_size attr in FracNZ_RNN or ND_RNN_Bias format.
759   void set_input_size(int64_t input_size);
760 
761   /// \brief Get input_size attr in FracNZ_RNN or ND_RNN_Bias format.
762   ///
763   /// \return input_size attr in FracNZ_RNN or ND_RNN_Bias format.
764   int64_t input_size() const;
765 
766   /// \brief Set hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
767   ///
768   /// \param[in] hidden_size hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
769   void set_hidden_size(int64_t hidden_size);
770 
771   /// \brief Get hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
772   ///
773   /// \return hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
774   int64_t hidden_size() const;
775 
776   /// \brief Set the debugging information of this AnfNode.
777   ///
778   /// \return New debugging information.
779   void set_debug_info(const NodeDebugInfoPtr &debug_info) override;
780 
781  private:
782   void Init();
783   struct FormatAttr {
784     int64_t fracz_group = 1;
785     int64_t input_size = 0;
786     int64_t hidden_size = 0;
787   };
788   std::string name_;
789   ValuePtr default_param_;
790   // Some attrs used in special format.
791   FormatAttr format_attrs_;
792   std::set<uint32_t> not_used_in_graphs_;
793   int used_graph_count_ = 0;
794   bool has_default_ = false;
795   bool has_dynamic_shape_ = false;
796   // Dynamic len is a flag indicating whether the parameter is dynamic sequence.
797   bool is_dynamic_len_ = false;
798   bool is_top_graph_param_ = false;
799 };
800 using ParameterPtr = std::shared_ptr<Parameter>;
801 using ParameterWeakPtr = std::weak_ptr<Parameter>;
802 
803 // Value is used to represent the atomic expression mentioned in BNF.
804 // It mainly be stored in ValueNode. Value and ValueNode is related definition.
805 class MS_CORE_API Value : public Base {
806  public:
807   /// \brief Default constructor.
808   Value() = default;
809 
810   /// \brief Constructor of Value.
811   ///
812   /// \param[in] t The type of this Value.
813   explicit Value(const TypePtr t);
814 
815   /// \brief Constructor of Value.
816   ///
817   /// \param[in] other Another Value.
818   Value(const Value &other);
819 
820   /// \brief Destructor.
821   ~Value() override = default;
822   MS_DECLARE_PARENT(Value, Base)
823 
824   /// \brief Get the type of this Value.
825   ///
826   /// \return The type.
827   TypePtr type() const;
828 
829   /// \brief Get the abstract value of Value.
830   ///
831   /// \return Abstract value of Value.
832   virtual abstract::AbstractBasePtr ToAbstract();
833 
834   /// \brief Check whether the input is the current Value object.
835   ///
836   /// \param[in] rhs The Value object to be compared.
837   /// \return Whether the input is the current Value object.
838   virtual bool operator==(const Value &rhs) const = 0;
839 
840   /// \brief Check whether the input is the current Value object.
841   ///
842   /// \param[in] other The Value object to be compared.
843   /// \return Whether the input is the current Value object.
844   Value &operator=(const Value &other);
845 
846   virtual bool ContainsValueAny() const;
847 
848  protected:
849   TypePtr type_{nullptr};
850 };
851 
852 // ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
853 // does not belong to any particular function graph.
854 class MS_CORE_API ValueNode final : public ANode {
855  public:
856   /// \brief Constructor of ValueNode.
857   ///
858   /// \param[in] value The value of this ValueNode.
859   explicit ValueNode(const ValuePtr &value);
860 
861   /// \brief Constructor of ValueNode.
862   ///
863   /// \param[in] value The value of this ValueNode.
864   /// \param[in] debug_info The debug info to be used for this ValueNode.
865   ValueNode(const ValuePtr &value, NodeDebugInfoPtr &&debug_info);
866 
867   /// \brief Destructor.
868   ~ValueNode() override = default;
869   MS_DECLARE_PARENT(ValueNode, ANode);
870 
871   void set_func_graph(const FuncGraphPtr &) override;
872 
873   void accept(AnfIrVisitor *v) override;
874 
875   /// \brief Set the value of this ValueNode.
876   ///
877   /// \param[in] value The value.
878   void set_value(const ValuePtr &value);
879 
880   /// \brief Get the value of this ValueNode.
881   ///
882   /// \return The value.
883   const ValuePtr &value() const;
884 
885   std::string fullname_with_scope() override;
886 
887   /// \brief Set whether this ValueNode has a new value.
888   ///
889   /// \param[in] flag Whether this ValueNode has a new value.
890   void set_has_new_value(bool flag);
891 
892   /// \brief Check whether this ValueNode has a new value.
893   ///
894   /// \return Whether this ValueNode has a new value.
895   bool has_new_value() const;
896 
897   /// \brief Get the count of graphs using this ValueNode.
898   ///
899   /// \return The count of graphs using this ValueNode.
900   size_t used_graph_count() const;
901 
902   /// \brief Set the count of groups using this ValueNode.
903   ///
904   /// \param[in] group The count of groups using this ValueNode.
905   void set_fracz_group(int64_t group);
906 
907   /// \brief Get groups attr in FRACTAL_Z format.
908   ///
909   /// \return Groups attr in FRACTAL_Z format.
910   int64_t fracz_group() const;
911 
912   /// \brief Set the count of graphs using this ValueNode.
913   ///
914   /// \param[in] used_graph_count The count of graphs using this ValueNode.
915   void set_used_graph_count(size_t used_graph_count);
916 
917   std::string ToString() const override;
918   std::string DebugString(int recursive_level = 1) const override;
919   std::string DebugString(bool recursive) const override;
920 
921   bool operator==(const AnfNode &other) const override;
922   friend std::ostream &operator<<(std::ostream &os, const ValueNodePtr &node) {
923     MS_EXCEPTION_IF_NULL(node);
924     os << node->ToString();
925     return os;
926   }
927 
928   /// \brief Set the debugging information of this AnfNode.
929   ///
930   /// \return New debugging information.
931   void set_debug_info(const NodeDebugInfoPtr &debug_info) override;
932 
933  private:
934   void Init();
935   struct FormatAttr {
936     int64_t fracz_group = 1;
937     int64_t input_size = 0;
938     int64_t hidden_size = 0;
939   };
940   FormatAttr format_attr_;
941   ValuePtr value_;
942   size_t used_graph_count_{0};
943   bool has_new_value_ = false;
944 };
945 
/// \brief Associates a C++ type with the pointer type of the Value subclass that wraps it.
///
/// The primary template is deliberately empty, so ImmTraits<X>::type only exists for types
/// that were registered with IMM_TRAITS.
template <class Raw>
struct ImmTraits {};

/// \brief Registers a mapping: ImmTraits<cpp_type>::type becomes imm_ptr_type.
#define IMM_TRAITS(imm_ptr_type, cpp_type) \
  template <>                              \
  struct ImmTraits<cpp_type> {             \
    using type = imm_ptr_type;             \
  };
954 
MakeValue(const ValuePtr & value)955 inline ValuePtr MakeValue(const ValuePtr &value) { return value; }
956 
957 template <typename S, typename U = typename ImmTraits<S>::type::element_type>
MakeValue(S v)958 inline ValuePtr MakeValue(S v) {
959   return std::make_shared<U>(v);
960 }
961 
962 template <typename S, typename U = typename ImmTraits<S>::type>
GetValue(const ValuePtr & value)963 static S GetValue(const ValuePtr &value) {
964   MS_EXCEPTION_IF_NULL(value);
965   auto imm = value->cast_ptr<typename U::element_type>();
966   if (imm == nullptr) {
967     MS_LOG(INTERNAL_EXCEPTION) << "Cast failed, original value: " << value->ToString()
968                                << ", type: " << value->type_name();
969   }
970   return imm->value();
971 }
972 
973 template <typename S,
974           typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
975                                   S>::type * = nullptr>
GetValue(const ValuePtr & value)976 static S GetValue(const ValuePtr &value) {
977   MS_EXCEPTION_IF_NULL(value);
978   S v = value->cast<S>();
979   if (v == nullptr) {
980     MS_LOG(INTERNAL_EXCEPTION) << "Cast failed, original value: " << value->ToString()
981                                << ", type: " << value->type_name();
982   }
983   return v;
984 }
985 
986 MS_CORE_API std::string GetCNodeFuncName(const CNodePtr &cnode);
987 
988 // Used to get FuncGraphPtr from a cnode first input
989 MS_CORE_API FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node);
990 
991 // Used to check whether an AnfNode is a cnode with a kind of Primitive as first input.
992 MS_CORE_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
993 
994 // Used to check whether an AnfNode is a cnode with a kind of Primitive as first input.
995 // If the Primitive is DoSignature, get the real Primitive firstly.
996 MS_CORE_API bool IsPrimitiveCNodeWithoutDoSignature(const AnfNodePtr &node, const PrimitivePtr &value);
997 
998 // Used to get PrimitivePtr from a cnode first input
999 MS_CORE_API PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
1000 
1001 // Return the function Primitive if DoSignaturePrimitive,
1002 // otherwise return the Primitive directly.
1003 MS_CORE_API PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node);
1004 // Check the first input of CNode.
1005 // Return the function Primitive if DoSignaturePrimitive,
1006 // otherwise return the Primitive directly.
1007 MS_CORE_API PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node);
1008 
1009 // Return the function value if DoSignaturePrimitive,
1010 // otherwise return the value directly.
1011 MS_CORE_API ValuePtr GetValueWithoutDoSignature(const ValuePtr &value);
1012 // Return the function value if DoSignaturePrimitive,
1013 // otherwise return the value directly.
1014 MS_CORE_API ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node);
1015 // Check the first input of CNode.
1016 // Return the function value if DoSignaturePrimitive,
1017 // otherwise return the value directly.
1018 MS_CORE_API ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node);
1019 
1020 /// \brief Used to check whether the given node is a ValueNode with some Primitive value.
1021 ///
1022 /// \param[in] node The input node.
1023 /// \param[in] value Primitive value.
1024 /// \return Whether the given node is a ValueNode with some Primitive value.
1025 MS_CORE_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value);
1026 
1027 // Check whether the given node is a ValueNode belonging to a primitive set.
1028 MS_CORE_API bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set);
1029 
1030 /// \brief Used to check whether the given node is a CNode belonging to a primitive set.
1031 ///
1032 /// \param[in] node The input node.
1033 /// \param[in] prim_set Primitive set.
1034 /// \return Whether the given node is a CNode belonging to a primitive set.
1035 MS_CORE_API bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);
1036 
1037 // Check whether two primitives are same.
1038 MS_CORE_API bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2);
1039 
1040 // Get number of AbstractMonad
1041 MS_CORE_API size_t GetAbstractMonadNum(const AbstractBasePtrList &args);
1042 
1043 // Check whether the given node has monad abstract.
1044 MS_CORE_API bool HasAbstractMonad(const AnfNodePtr &node);
1045 
1046 // Check whether the given node has U monad abstract.
1047 MS_CORE_API bool HasAbstractUMonad(const AnfNodePtr &node);
1048 
1049 // Check whether the given node has IO monad abstract.
1050 MS_CORE_API bool HasAbstractIOMonad(const AnfNodePtr &node);
1051 
1052 // Gets primitive attribute value as a bool flag.
1053 MS_CORE_API bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr);
1054 
1055 // Gets effect info from a primitive by its attributes.
1056 MS_CORE_API EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim);
1057 
1058 // Check if monad state is equivalent for the connected two nodes, not strict but more faster.
1059 MS_CORE_API bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner);
1060 
1061 // Check if the node is DeadNode.
1062 MS_CORE_API bool IsDeadNode(const AnfNodePtr &node);
1063 
1064 // Check if the node is PolyNode.
1065 MS_CORE_API bool IsPolyNode(const AnfNodePtr &node);
1066 
1067 // Used to check whether a ValueNode has some kind of value.
1068 template <typename T>
IsValueNode(const AnfNodePtr & node)1069 bool IsValueNode(const AnfNodePtr &node) {
1070   auto value_node = dyn_cast_ptr<ValueNode>(node);
1071   if (value_node == nullptr) {
1072     return false;
1073   }
1074   const auto &value = value_node->value();
1075   return (value != nullptr) && (value->isa<T>());
1076 }
1077 
GetValueNode(const AnfNodePtr & node)1078 inline ValuePtr GetValueNode(const AnfNodePtr &node) {
1079   auto value_node = dyn_cast_ptr<ValueNode>(node);
1080   return (value_node == nullptr) ? nullptr : value_node->value();
1081 }
1082 
GetValuePtr(const AnfNodePtr & node)1083 inline Value *GetValuePtr(const AnfNodePtr &node) {
1084   auto value_node = dyn_cast_ptr<ValueNode>(node);
1085   return (value_node == nullptr) ? nullptr : value_node->value().get();
1086 }
1087 
1088 template <typename S,
1089           typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
1090                                   S>::type * = nullptr>
GetValueNode(const AnfNodePtr & node)1091 inline S GetValueNode(const AnfNodePtr &node) {
1092   auto value = GetValuePtr(node);
1093   return (value == nullptr) ? nullptr : value->cast<S>();
1094 }
1095 
1096 template <typename S, typename std::enable_if<std::is_base_of<Value, S>::value, S>::type * = nullptr>
GetValuePtr(const AnfNodePtr & node)1097 inline S *GetValuePtr(const AnfNodePtr &node) {
1098   auto value_node = dyn_cast_ptr<ValueNode>(node);
1099   if (value_node == nullptr) {
1100     return nullptr;
1101   }
1102   const auto &value = value_node->value();
1103   return (value == nullptr) ? nullptr : value->cast_ptr<S>();
1104 }
1105 
1106 MS_CORE_API SeenNum NewSeenGeneration();
1107 
1108 namespace id_generator {
1109 MS_CORE_API std::string get_id(const std::string &front_string);
1110 MS_CORE_API void reset_id();
1111 MS_CORE_API void reset_id_with_offset();
1112 }  // namespace id_generator
1113 using TaggedNodeMap = mindspore::HashMap<AnfNodePtr, size_t>;
1114 using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
1115 MS_CORE_API std::string GetCNodeTarget(const AnfNodePtr &node);
1116 std::string GetOriginNodeTarget(const AnfNodePtr &node);
1117 MS_CORE_API bool ContainMultiTarget(const AnfNodePtrList &nodes);
1118 struct GraphSegment {
GraphSegmentGraphSegment1119   GraphSegment(const AnfNodePtrList &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
AddPreSegmentGraphSegment1120   void AddPreSegment(const std::shared_ptr<GraphSegment> &segment) { (void)pre_segments_.insert(segment); }
1121   AnfNodePtrList nodes_;
1122   std::set<std::shared_ptr<GraphSegment>> pre_segments_;
1123   bool is_cut_{false};
1124   uint32_t graph_id_{0};
1125 };
1126 using GraphSegmentPtr = std::shared_ptr<GraphSegment>;
1127 
1128 constexpr auto kElementsUseFlagsKey = "elements_use_flags";
GetSequenceNodeElementsUseFlags(const AnfNodePtr & node)1129 inline std::shared_ptr<std::vector<bool>> GetSequenceNodeElementsUseFlags(const AnfNodePtr &node) {
1130   MS_EXCEPTION_IF_NULL(node);
1131   return node->template user_data<std::vector<bool>>(kElementsUseFlagsKey);
1132 }
1133 
SetSequenceNodeElementsUseFlags(const AnfNodePtr & node,const std::shared_ptr<std::vector<bool>> & flags)1134 inline void SetSequenceNodeElementsUseFlags(const AnfNodePtr &node, const std::shared_ptr<std::vector<bool>> &flags) {
1135   MS_EXCEPTION_IF_NULL(node);
1136   node->set_user_data(kElementsUseFlagsKey, flags);
1137 }
1138 
1139 // Set the sequence nodes' elements use flags to 'new_flag' at specific 'index' position.
1140 MS_CORE_API void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, std::size_t index, bool new_flag);
1141 // Set the sequence nodes' elements use flags all to 'new_flag'.
1142 MS_CORE_API void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag);
1143 // Set the sequence nodes' elements use flags all to 'new_flag' recursively.
1144 MS_CORE_API void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag);
1145 }  // namespace mindspore
1146 #endif  // MINDSPORE_CORE_IR_ANF_H_
1147