1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_IR_ANF_H_
20 #define MINDSPORE_CORE_IR_ANF_H_
21
22 #include <functional>
23 #include <string>
24 #include <vector>
25 #include <memory>
26 #include <utility>
27 #include <set>
28 #include <bitset>
29 #include "utils/hash_map.h"
30 #include "utils/hash_set.h"
31 #include "base/base.h"
32 #include "base/effect_info.h"
33 #include "ir/kernel_info_dev.h"
34 #include "ir/scope.h"
35 #include "ir/primal_attr.h"
36 #include "ir/primal_debug_info.h"
37 #include "utils/info.h"
38 #include "utils/hashing.h"
39 #include "utils/ms_utils.h"
40 #include "utils/os.h"
41
// The MindSpore ANF IR is defined here, following this BNF:
//   <ANode>   ::= Scalar | Named | Tensor | Var |
//                 Prim | MetaFuncGraph | FuncGraph | Type |
//                 Shape | Param
//   <CNode>   ::= (<ANode> ...)
//   <AnfNode> ::= <CNode> | <ANode>
// ANode: Atomic Node
// CNode: Complex Node
51 namespace mindspore {
namespace abstract {
class BaseShape;
class AbstractBase;
}  // namespace abstract
// Shared-pointer aliases for the abstract (inferred type/shape/value) classes forward-declared above.
using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
// Set of node debug infos, ordered by DebugInfoCompare.
using NodeDebugInfoSet = std::set<NodeDebugInfoPtr, DebugInfoCompare>;
// Counter type of the AnfNode::seen_ / AnfNode::extra_seen_ markers.
using SeenNum = uint32_t;

class Value;
using ValuePtr = std::shared_ptr<Value>;
using ValuePtrList = std::vector<ValuePtr>;

class ValueNode;
using ValueNodePtr = std::shared_ptr<ValueNode>;

class CNode;
using CNodePtr = std::shared_ptr<CNode>;
using CNodePtrList = std::vector<CNodePtr>;
using CNodeWeakPtr = std::weak_ptr<CNode>;

class FuncGraph;
// OrderedSet comes from an included header (presumably insertion-ordered — confirm in its definition).
using FuncGraphSet = OrderedSet<FuncGraphPtr>;
using FuncGraphVector = std::vector<FuncGraphPtr>;

class Primitive;
using PrimitivePtr = std::shared_ptr<Primitive>;
struct PrimitiveHasher;
struct PrimitiveEqual;
// Hash set of primitives keyed through PrimitiveHasher/PrimitiveEqual.
// NOTE(review): presumably compares primitives by content rather than by pointer identity — confirm.
using PrimitiveSet = mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;

class BaseRef;

class Var;
using VarPtr = std::shared_ptr<Var>;

class AnfIrVisitor;

class ParamInfo;
using ParamInfoPtr = std::shared_ptr<ParamInfo>;
93
// AnfNode is the basic class of the IR definition derived from Base.
// Only two kinds of nodes derive from it: CNode and ANode.
// Methods:
//   func_graph: return the FuncGraph that this AnfNode belongs to.
//   scope: return the scope namespace of this AnfNode. Set it using set_scope.
//   abstract: return the cached inferred abstract value. It contains type, shape
//     and value. Set a new cache using set_abstract.
//   Type/Shape: return the related info of this AnfNode. When this AnfNode is an
//     input of other CNodes, you can get the related info by this method.
//   debug_info: return the information retrieved from the parser. Set it using set_debug_info.
//   fullname_with_scope: return the detailed debug info.

/// \brief AnfNode is the basic class of the IR definition derived from Base.
class MS_CORE_API AnfNode : public Base {
 public:
  /// \brief Constructor.
  ///
  /// \param[in] func_graph The FuncGraph to which this AnfNode belongs.
  /// \param[in] debug_info The debug info to be used for this AnfNode (taken by rvalue reference).
  AnfNode(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info);

  /// \brief Constructor.
  ///
  /// \param[in] func_graph The FuncGraph to which this AnfNode belongs.
  explicit AnfNode(const FuncGraphPtr &func_graph);

  /// \brief Destructor.
  ~AnfNode() override = default;
  MS_DECLARE_PARENT(AnfNode, Base);

  /// \brief Use the method of the AnfIrVisitor class to process the node.
  ///
  /// The (unnamed) argument is the visitor that processes this node; subclasses override this.
  virtual void accept(AnfIrVisitor *);

  /// \brief Obtain the FuncGraph to which this AnfNode belongs.
  ///
  /// \return The FuncGraph to which this AnfNode belongs.
  FuncGraphPtr func_graph() const;

  /// \brief Set the FuncGraph to which this AnfNode belongs.
  ///
  /// \param[in] func_graph The input FuncGraph.
  virtual void set_func_graph(const FuncGraphPtr &func_graph);

  /// \brief Obtain the scope namespace of this AnfNode.
  ///
  /// \return The scope namespace.
  ScopePtr scope();

  /// \brief Set the scope namespace of this AnfNode.
  ///
  /// \param[in] scope New scope namespace.
  void set_scope(const ScopePtr &scope);

  /// \brief Obtain device kernel program information.
  ///
  /// \return Device kernel program information.
  const KernelInfoDevice *kernel_info() const;

  /// \brief Obtain device kernel program information.
  ///
  /// \return Device kernel program information.
  KernelInfoDevice *kernel_info();

  /// \brief Obtain the pointer of KernelInfoDevice.
  ///
  /// \return The pointer of KernelInfoDevice.
  KernelInfoDevicePtr kernel_info_ptr() const;

  /// \brief Set device kernel program information.
  ///
  /// \param[in] kernel_info New device kernel program information.
  void set_kernel_info(const KernelInfoDevicePtr &kernel_info);

  /// \brief Obtain the inferred abstract value of this AnfNode.
  ///
  /// \return The inferred abstract value.
  const AbstractBasePtr &abstract() const;

  /// \brief Set the abstract value of this AnfNode.
  ///
  /// \param[in] abs New abstract value.
  void set_abstract(const AbstractBasePtr &abs);

  /// \brief Obtain the debugging information of this AnfNode.
  ///
  /// \return The debugging information of this AnfNode.
  NodeDebugInfoPtr debug_info();

  /// \brief Set the debugging information of this AnfNode.
  ///
  /// \param[in] debug_info New debugging information.
  virtual void set_debug_info(const NodeDebugInfoPtr &debug_info);

  /// \brief Obtain the type of the element in this AnfNode.
  ///
  /// \return The type of the element.
  TypePtr Type() const;

  /// \brief Obtain the shape of the element in this AnfNode.
  ///
  /// \return The shape of the element.
  BaseShapePtr Shape() const;

  /// \brief Obtain the hash value of this AnfNode. Declared final: subclasses cannot change it.
  ///
  /// \return The hash value.
  std::size_t hash() const final;

  /// \brief Obtain detailed information about scope namespace.
  ///
  /// \return Detailed information about scope namespace.
  virtual std::string fullname_with_scope();

  /// \brief Obtain the unique name of this AnfNode.
  ///
  /// \return The unique name of this AnfNode.
  std::string UniqueName();

  /// \brief Obtain the display information of this AnfNode.
  ///
  /// \param[in] recursive_level Recursion level when displayed.
  /// \return Information to be displayed.
  virtual std::string DebugString(int recursive_level = 1) const;

  /// \brief Obtain the display information of this AnfNode.
  ///
  /// \param[in] recursive Whether to display AnfNode recursively.
  /// \return Information to be displayed.
  virtual std::string DebugString(bool recursive) const;

  /// \brief Override of Base::ToString().
  std::string ToString() const override;

  /// \brief Override of Base::dump().
  void dump() const override;

  /// \brief Obtain the unique id of the debug information of this AnfNode.
  ///
  /// \return Unique id.
  std::string UniqueId();

  /// \brief Obtain the unique id through copied traced information.
  ///
  /// \return Unique id.
  std::string UniqueIdThroughCopy();

  /// \brief Determine whether two AnfNodes are the same.
  ///
  /// \param[in] other Another AnfNode.
  /// \return True if the same, otherwise False.
  virtual bool operator==(const AnfNode &other) const;

  /// \brief Obtain the display information of this AnfNode.
  ///
  /// \param[in] os Output stream.
  /// \param[in] node AnfNode to be displayed.
  /// \return Output stream.
  friend std::ostream &operator<<(std::ostream &os, const AnfNode &node);

  /// \brief Check if there is an interpret node.
  ///
  /// \return True if there is an interpret node, otherwise false.
  bool interpret() const;

  /// \brief Set whether to use interpretation.
  ///
  /// \param[in] interpret Boolean.
  void set_interpret(const bool &interpret);

  /// \brief Check if there is an interpret node related to the unsupported internal type.
  ///
  /// \return True if there is an interpret node related to the unsupported internal type, otherwise false.
  bool interpret_internal_type();

  /// \brief Set whether there is an interpret node with unsupported internal type.
  ///
  /// \param[in] interpret_internal_type Boolean.
  void set_interpret_internal_type(const bool &interpret_internal_type);

  // Traversal markers, public and default 0.
  // NOTE(review): presumably compared against a traversal-wide "seen" generation to mark visited nodes — confirm.
  SeenNum seen_{0};
  SeenNum extra_seen_{0};

 protected:
  // Hold a weak ref to Graph as Graph also holds refs to AnfNode.
  // Otherwise, func_graph_ and AnfNode would form a reference cycle.
  FuncGraphWeakPtr func_graph_;
  // Presumably backs abstract()/set_abstract().
  AbstractBasePtr abstract_;
  // Presumably backs debug_info()/set_debug_info().
  NodeDebugInfoPtr debug_info_;
  // NOTE(review): presumably caches the result of fullname_with_scope() — confirm in the .cc file.
  std::string fullname_with_scope_;

 private:
  // Presumably bit indices into interpret_flags_; kNumInterpretFlags sizes that bitset.
  static constexpr size_t kInterpret = 0;
  static constexpr size_t kInterpretInternalType = 1;
  static constexpr size_t kNumInterpretFlags = 2;
  // NOTE(review): presumably the key under which the kernel info is stored — confirm in the .cc file.
  static constexpr auto kKernelInfoKey = "kernel_info";

  ScopePtr scope_;
  std::bitset<kNumInterpretFlags> interpret_flags_;
};
288
// CNode represents a complex node with a set of arguments.
// Fields:
//   inputs_: all of the inputs of this CNode.
//   weak_inputs_: all of the weak inputs of this CNode.
//     Use input(i) to get the input at index i.
//     Use inputs() to get all the inputs as a vector.
//     Use add_input(input) to append a new input to a CNode.
//     Use set_input(i, input) to change one of the inputs.
//     Use set_inputs(inputs) to replace all of the inputs of a CNode.
//   func_graph_as_var: used in opt pattern matching to match a real FuncGraph.
//   stop_gradient: a flag used to stop gradient.
//     Use stop_gradient() to get this flag, mainly used in ad.
//     Use set_stop_gradient() to set this flag.
class MS_CORE_API CNode final : public AnfNode, public EffectInfoHolder {
 public:
  /// \brief Constructor.
  ///
  /// \param[in] inputs Input nodes of this CNode.
  /// \param[in] func_graph The FuncGraph to which this CNode belongs.
  CNode(const AnfNodePtrList &inputs, const FuncGraphPtr &func_graph);

  /// \brief Constructor.
  ///
  /// \param[in] inputs Input nodes of this CNode.
  /// \param[in] func_graph_as_var A Var standing for the FuncGraph of this CNode, used in opt pattern matching.
  CNode(const AnfNodePtrList &inputs, const VarPtr &func_graph_as_var);

  /// \brief Constructor.
  ///
  /// \param[in] weak_inputs Input nodes of this CNode, as weak references (taken by rvalue reference).
  /// \param[in] func_graph The FuncGraph to which this CNode belongs.
  CNode(AnfNodeWeakPtrList &&weak_inputs, const FuncGraphPtr &func_graph);

  /// \brief Constructor.
  ///
  /// \param[in] weak_inputs Input nodes of this CNode, as weak references.
  /// \param[in] func_graph The FuncGraph to which this CNode belongs.
  CNode(const AnfNodeWeakPtrList &weak_inputs, const FuncGraphPtr &func_graph);

  /// \brief Constructor.
  ///
  /// \param[in] weak_inputs Input nodes of this CNode, as weak references (taken by rvalue reference).
  /// \param[in] func_graph The FuncGraph to which this CNode belongs.
  /// \param[in] debug_info The debug info to be used for this CNode.
  CNode(AnfNodeWeakPtrList &&weak_inputs, const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info);

  /// \brief Destructor.
  ~CNode() override = default;
  MS_DECLARE_PARENT(CNode, AnfNode);

  void accept(AnfIrVisitor *v) override;

  /// \brief Check whether the first input of this CNode is the given primitive value.
  ///
  /// \param[in] value The primitive to compare with.
  /// \return True if they have the same primitive value, otherwise false.
  bool IsApply(const PrimitivePtr &value) const;

  // NOTE(review): the top-level const on the by-value return types of size(), empty() and input() has no
  // effect on scalars and, for input(), prevents moving from the returned AnfNodePtr. Dropping it must be
  // done together with the .cc definitions, since the return types have to match.

  /// \brief Obtain the size of input nodes of this CNode.
  ///
  /// \return Size of input nodes.
  const size_t size() const;

  /// \brief Check if this CNode has no input node.
  ///
  /// \return True if input nodes is empty, otherwise false.
  const bool empty() const;

  /// \brief Get the input node of the given index.
  /// The weak_input() is recommended.
  ///
  /// \param[in] i The given index.
  /// \return The input node of the given index.
  const AnfNodePtr input(size_t i) const;

  /// \brief Get the input node of the given index, as a weak reference.
  ///
  /// \param[in] i The given index.
  /// \return The input node of the given index.
  const AnfNodeWeakPtr &weak_input(size_t i) const;

  /// \brief Get the input nodes.
  /// The weak_inputs() is recommended. This accessor is non-const: calling it fills inputs_
  /// (kept for compatibility with earlier versions).
  ///
  /// \return The input nodes of this CNode.
  const AnfNodePtrList &inputs();

  /// \brief Get the input nodes, as weak references.
  ///
  /// \return The input nodes of this CNode.
  const AnfNodeWeakPtrList &weak_inputs() const;

  /// \brief Add the input node to this CNode.
  ///
  /// \param[in] input Node.
  void add_input(const AnfNodePtr &input);

  /// \brief Set the input node of the given index.
  ///
  /// \param[in] i The given index.
  /// \param[in] new_input Node.
  void set_input(size_t i, const AnfNodePtr &new_input);

  /// \brief Set the input nodes for this CNode.
  ///
  /// \param[in] inputs Input nodes.
  void set_inputs(const AnfNodePtrList &inputs);

  /// \brief Set the input nodes for this CNode, as weak references.
  ///
  /// \param[in] weak_inputs Input nodes.
  void set_weak_inputs(const AnfNodeWeakPtrList &weak_inputs);

  // OutputValue stores the CNode value and id in PyNative mode.
  using OutputValue = std::pair<ValueNodePtr, std::string>;

  /// \brief Record the cnode value and id to output_value_.
  ///
  /// \param[in] forward The cnode value.
  /// \param[in] id The id.
  void set_forward(const ValueNodePtr &forward, const std::string &id);

  /// \brief Get the record of output value of this CNode.
  ///
  /// \return The output value of this CNode.
  const OutputValue &forward() const;

  /// \brief Check if stop_gradient is set.
  ///
  /// \return True if stop_gradient is set, otherwise false.
  bool stop_gradient() const;

  /// \brief Set stop_gradient.
  ///
  /// \param[in] stop_gradient Boolean.
  void set_stop_gradient(bool stop_gradient);

  std::string fullname_with_scope() override;

  /// \brief Set fullname_with_scope for this CNode.
  ///
  /// \param[in] full_name The fullname_with_scope.
  void set_fullname_with_scope(const std::string full_name);

  std::string DebugString(int recursive_level = 1) const override;
  std::string DebugString(bool recursive) const override;

  /// \brief Set in_forward_flag for this CNode.
  ///
  /// \param[in] flag Boolean.
  void set_in_forward_flag(bool flag);
  /// \brief Check if in_forward_flag is set.
  ///
  /// \return True if in_forward_flag is set, otherwise false.
  bool in_forward_flag() const;

  /// \brief Set whether the primitive of this CNode is Load.
  ///
  /// \param[in] is_load Boolean.
  void set_load_flag(bool is_load);
  /// \brief Check if is_load_ is set.
  ///
  /// \return True if is_load_ is set, otherwise false.
  bool get_load_flag() const;

  /// \brief Get func_graph_as_var of this CNode.
  ///
  /// \return func_graph_as_var.
  VarPtr func_graph_as_var() const;

  /// \brief Get all attributes of this CNode.
  ///
  /// \return Attributes of this CNode.
  const mindspore::HashMap<std::string, ValuePtr> &attrs() const;
  /// \brief Set the attributes of this CNode.
  ///
  /// \param[in] attrs The attributes.
  void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs);

  /// \brief Add a new attribute to this CNode.
  ///
  /// \param[in] name The name of the new attribute.
  /// \param[in] attr The value of the new attribute.
  void AddAttr(const std::string &name, const ValuePtr &attr);

  /// \brief Erase the attribute with the given name.
  ///
  /// \param[in] name The name of attribute.
  void EraseAttr(const std::string &name);

  /// \brief Get the attribute with the given name.
  ///
  /// \param[in] name The name of attribute.
  /// \return Attribute.
  ValuePtr GetAttr(const std::string &name) const;

  /// \brief Check whether this CNode has an attribute with the given name.
  ///
  /// \param[in] name The name of attribute.
  /// \return Boolean.
  bool HasAttr(const std::string &name) const;

  /// \brief Get the number of input tensors.
  ///
  /// \return The number of input tensors.
  ssize_t input_tensor_num() const;

  /// \brief Get the primal attributes of this CNode.
  ///
  /// \return The primal attributes.
  const mindspore::HashMap<std::string, ValuePtr> &primal_attrs() const;

  /// \brief Set the primal attributes of this CNode.
  ///
  /// \param[in] attrs The primal attributes.
  void set_primal_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs);

  /// \brief Add the primal attribute to this CNode.
  ///
  /// \param[in] name The name of the attribute.
  /// \param[in] attr The attribute.
  void AddPrimalAttr(const std::string &name, const ValuePtr &attr);

  /// \brief Erase the primal attribute with the given name.
  ///
  /// \param[in] name The name of the attribute.
  void ErasePrimalAttr(const std::string &name);

  /// \brief Get the primal attribute with the given name.
  ///
  /// \param[in] name The name of the attribute.
  /// \return The primal attribute with the given name.
  ValuePtr GetPrimalAttr(const std::string &name) const;

  /// \brief Check whether this CNode has a primal attribute with the given name.
  ///
  /// \param[in] name The name of the attribute.
  /// \return True if it exists, otherwise false.
  bool HasPrimalAttr(const std::string &name) const;

  /// \brief Get primal debug information.
  ///
  /// \return The primal debug information.
  NodeDebugInfoSet primal_debug_infos() const;

  /// \brief Set primal debug information.
  ///
  /// \param[in] debug_infos Debug information of this CNode.
  void set_primal_debug_infos(const NodeDebugInfoSet &debug_infos);

  /// \brief Add a primal debug information.
  ///
  /// \param[in] debug_info A debug information.
  void AddPrimalDebugInfo(const NodeDebugInfoPtr &debug_info);

  /// \brief Copy CNode-level information from another CNode.
  /// NOTE(review): the set of copied fields is defined in the .cc file — confirm there.
  ///
  /// \param[in] node The CNode to copy from.
  void CloneCNodeInfo(const CNodePtr &node);

  /// \brief Set the number of input tensors.
  ///
  /// \param[in] input_tensor_num The number of input tensors.
  void set_input_tensor_num(ssize_t input_tensor_num);

  /// \brief Check whether the effect has been handled.
  ///
  /// \return True if the effect has been handled, otherwise false.
  bool IsEffectHandled() const;

  /// \brief Set whether the effect has been handled.
  ///
  /// \param[in] handled Boolean.
  void SetEffectHandled(bool handled);

  /// \brief Get the debug infos of fused nodes.
  ///
  /// \return A set of debug infos.
  NodeDebugInfoSet fused_debug_infos() const;

  /// \brief Set the debug infos of fused nodes for this CNode.
  ///
  /// \param fused_debug_infos The debug infos to be set.
  void set_fused_debug_infos(const NodeDebugInfoSet &fused_debug_infos);

  /// \brief Add a node's debug info or fused debug info.
  ///
  /// \param node An anf node.
  void AddFusedDebugInfo(const AnfNodePtr &node);

  /// \brief Add a vector of nodes' debug info or fused debug info.
  ///
  /// \param nodes A vector of anf nodes.
  void AddFusedDebugInfoList(const AnfNodePtrList &nodes);

  /// \brief Add a node debug info.
  ///
  /// \param debug_info A node debug info of an anf node.
  void AddFusedDebugInfo(const NodeDebugInfoPtr &debug_info);

  /// \brief Add a list of node debug infos.
  ///
  /// \param debug_infos A list of node debug infos.
  void AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos);

  /// \brief Check whether this CNode has an input or indirect input that is a Depend CNode with an isolated
  /// side-effect node.
  ///
  /// \return True if it has, otherwise false.
  bool has_side_effect_node() const;

  /// \brief Set whether this CNode has an input or indirect input that is a Depend CNode with an isolated
  /// side-effect node.
  ///
  /// \param[in] has_side_effect_node Boolean.
  void set_has_side_effect_node(bool has_side_effect_node);

  /// \brief Set the debugging information of this CNode.
  ///
  /// \param[in] debug_info New debugging information.
  void set_debug_info(const NodeDebugInfoPtr &debug_info) override;

 private:
  void Init();
  void CheckCNodeWeakInput();

  // Presumably bit indices into flags_; kNumFlags sizes that bitset.
  static constexpr size_t kStopGradient = 0;
  static constexpr size_t kInForwardFlag = 1;
  static constexpr size_t kEffectHandled = 2;
  static constexpr size_t kIsLoad = 3;
  static constexpr size_t kNumFlags = 4;
  // NOTE(review): presumably keys under which func_graph_as_var and the output value are stored — confirm in .cc.
  static constexpr auto kFuncGraphVarKey = "fg_var";
  static constexpr auto kOutputValueKey = "out_value";

  AnfNodeWeakPtrList weak_inputs_;
  // Empty most of the time: it is only filled after CNode::inputs() is called,
  // for compatibility with earlier versions.
  AnfNodePtrList inputs_;
  // NOTE(review): presumably records whether inputs_ has been filled from weak_inputs_ — confirm.
  bool inputs_bound_{false};

  // Defaults to -1. NOTE(review): presumably means "not set" — confirm.
  ssize_t input_tensor_num_ = -1;
  std::bitset<kNumFlags> flags_;

  mindspore::HashMap<std::string, ValuePtr> attrs_;
  mindspore::HashMap<std::string, ValuePtr> primal_attrs_;
  NodeDebugInfoSet primal_debug_infos_;
  NodeDebugInfoSet fused_debug_infos_;

  // If the inputs or their inputs contain Depend CNode with isolated side-effect node.
  bool has_side_effect_node_{false};
};
631
632 // ANode represents the atomic node. It's derived Parameter and ValueNode.
633 class MS_CORE_API ANode : public AnfNode {
634 public:
ANode()635 ANode() : AnfNode(nullptr) {}
636
637 /// \brief Constructor.
638 ///
639 /// \param[in] func_graph The FuncGraph to which this ANode belongs.
ANode(const FuncGraphPtr & func_graph)640 explicit ANode(const FuncGraphPtr &func_graph) : AnfNode(func_graph) {}
641
642 /// \brief Constructor.
643 ///
644 /// \param[in] func_graph The FuncGraph to which this ANode belongs.
645 /// \param[in] debug_info The debug info to be used for this ANode.
ANode(const FuncGraphPtr & func_graph,NodeDebugInfoPtr && debug_info)646 ANode(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info) : AnfNode(func_graph, std::move(debug_info)) {}
647
648 /// \brief Destructor.
649 virtual ~ANode() = default;
650
651 MS_DECLARE_PARENT(ANode, AnfNode);
652 };
653
654 // Parameter represents the parameter inputs of a function. They have no value.
655 // Attributes:
656 // default_param_value_: used to hold the inputting tensor of the model.
657 class MS_CORE_API Parameter final : public ANode {
658 public:
659 explicit Parameter(const FuncGraphPtr &func_graph);
660
661 Parameter(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info);
662
663 /// \brief Destructor.
664 ~Parameter() override = default;
665 MS_DECLARE_PARENT(Parameter, ANode);
666
667 void accept(AnfIrVisitor *v) override;
668 std::string DebugString(int recursive_level = 1) const override;
669
670 /// \brief Get the name of this Parameter.
671 ///
672 /// \return The name.
673 std::string name() const;
674
675 /// \brief Set the name of this Parameter.
676 ///
677 /// \param[in] The name.
678 void set_name(const std::string &name);
679
680 std::string fullname_with_scope() override;
681
682 /// \brief Check if there is a default parameter.
683 ///
684 /// \return True if this Parameter has a default parameter, otherwise false.
685 bool has_default() const;
686
687 /// \brief Set the default parameter.
688 ///
689 /// \param[in] param The default parameter.
690 void set_default_param(const ValuePtr ¶m);
691
692 /// \brief Get the default parameter.
693 ///
694 /// \return The default parameter.
695 const ValuePtr &default_param() const;
696
697 /// \brief Get the parameter information.
698 ///
699 /// \return The parameter information.
700 ParamInfoPtr param_info() const;
701
702 /// \brief Increase used_graph_count.
703 void IncreaseUsedGraphCount();
704 /// \brief Decrease used_graph_count.
705 void DecreaseUsedGraphCount();
706 /// \brief Get used_graph_count.
707 ///
708 /// \return used_graph_count.
709 int used_graph_count() const;
710
711 bool is_top_graph_param() const;
712 void set_is_top_graph_param(bool flag);
713
714 bool operator==(const AnfNode &other) const override;
715
716 /// \brief This parameter is not used in graph with id.
717 ///
718 /// \param[in] graph_id The graph id.
719 void SetNotUsedByRealKernelInGraph(uint32_t graph_id);
720
721 /// \brief Check if this Parameter is used in graph with id.
722 ///
723 /// \param[in] graph_id True if used, otherwise false.
724 bool IsUsedByRealKernelInGraph(uint32_t graph_id) const;
725
726 /// \brief Set whether this Parameter has a dynamic shape.
727 ///
728 /// \param[in] flag Boolean.
729 void set_has_dynamic_shape(bool flag);
730
731 /// \brief Check whether this Parameter has a dynamic shape.
732 ///
733 /// \return True if this Parameter has a dynamic shape, otherwise false.
734 bool has_dynamic_shape() const;
735
736 /// \brief Set whether this Parameter is dynamic len.
737 ///
738 /// \param[in] flag Boolean.
739 void set_dynamic_len(bool flag);
740
741 /// \brief Check whether this Parameter is dynamic len.
742 ///
743 /// \return True if this Parameter is dynamic len, otherwise false.
744 bool dynamic_len() const;
745
746 /// \brief Set groups attr in FRACTAL_Z format.
747 ///
748 /// \param[in] fracz_group Groups attr in FRACTAL_Z format.
749 void set_fracz_group(int64_t fracz_group);
750
751 /// \brief Get groups attr in FRACTAL_Z format.
752 ///
753 /// \return Groups attr in FRACTAL_Z format.
754 int64_t fracz_group() const;
755
756 /// \brief Set input_size attr in FracNZ_RNN or ND_RNN_Bias format.
757 ///
758 /// \param[in] input_size input_size attr in FracNZ_RNN or ND_RNN_Bias format.
759 void set_input_size(int64_t input_size);
760
761 /// \brief Get input_size attr in FracNZ_RNN or ND_RNN_Bias format.
762 ///
763 /// \return input_size attr in FracNZ_RNN or ND_RNN_Bias format.
764 int64_t input_size() const;
765
766 /// \brief Set hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
767 ///
768 /// \param[in] hidden_size hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
769 void set_hidden_size(int64_t hidden_size);
770
771 /// \brief Get hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
772 ///
773 /// \return hidden_size attr in FracNZ_RNN or ND_RNN_Bias format.
774 int64_t hidden_size() const;
775
776 /// \brief Set the debugging information of this AnfNode.
777 ///
778 /// \return New debugging information.
779 void set_debug_info(const NodeDebugInfoPtr &debug_info) override;
780
781 private:
782 void Init();
783 struct FormatAttr {
784 int64_t fracz_group = 1;
785 int64_t input_size = 0;
786 int64_t hidden_size = 0;
787 };
788 std::string name_;
789 ValuePtr default_param_;
790 // Some attrs used in special format.
791 FormatAttr format_attrs_;
792 std::set<uint32_t> not_used_in_graphs_;
793 int used_graph_count_ = 0;
794 bool has_default_ = false;
795 bool has_dynamic_shape_ = false;
796 // Dynamic len is a flag indicating whether the parameter is dynamic sequence.
797 bool is_dynamic_len_ = false;
798 bool is_top_graph_param_ = false;
799 };
800 using ParameterPtr = std::shared_ptr<Parameter>;
801 using ParameterWeakPtr = std::weak_ptr<Parameter>;
802
803 // Value is used to represent the atomic expression mentioned in BNF.
804 // It mainly be stored in ValueNode. Value and ValueNode is related definition.
805 class MS_CORE_API Value : public Base {
806 public:
807 /// \brief Default constructor.
808 Value() = default;
809
810 /// \brief Constructor of Value.
811 ///
812 /// \param[in] t The type of this Value.
813 explicit Value(const TypePtr t);
814
815 /// \brief Constructor of Value.
816 ///
817 /// \param[in] other Another Value.
818 Value(const Value &other);
819
820 /// \brief Destructor.
821 ~Value() override = default;
822 MS_DECLARE_PARENT(Value, Base)
823
824 /// \brief Get the type of this Value.
825 ///
826 /// \return The type.
827 TypePtr type() const;
828
829 /// \brief Get the abstract value of Value.
830 ///
831 /// \return Abstract value of Value.
832 virtual abstract::AbstractBasePtr ToAbstract();
833
834 /// \brief Check whether the input is the current Value object.
835 ///
836 /// \param[in] rhs The Value object to be compared.
837 /// \return Whether the input is the current Value object.
838 virtual bool operator==(const Value &rhs) const = 0;
839
840 /// \brief Check whether the input is the current Value object.
841 ///
842 /// \param[in] other The Value object to be compared.
843 /// \return Whether the input is the current Value object.
844 Value &operator=(const Value &other);
845
846 virtual bool ContainsValueAny() const;
847
848 protected:
849 TypePtr type_{nullptr};
850 };
851
852 // ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
853 // does not belong to any particular function graph.
854 class MS_CORE_API ValueNode final : public ANode {
855 public:
856 /// \brief Constructor of ValueNode.
857 ///
858 /// \param[in] value The value of this ValueNode.
859 explicit ValueNode(const ValuePtr &value);
860
861 /// \brief Constructor of ValueNode.
862 ///
863 /// \param[in] value The value of this ValueNode.
864 /// \param[in] debug_info The debug info to be used for this ValueNode.
865 ValueNode(const ValuePtr &value, NodeDebugInfoPtr &&debug_info);
866
867 /// \brief Destructor.
868 ~ValueNode() override = default;
869 MS_DECLARE_PARENT(ValueNode, ANode);
870
871 void set_func_graph(const FuncGraphPtr &) override;
872
873 void accept(AnfIrVisitor *v) override;
874
875 /// \brief Set the value of this ValueNode.
876 ///
877 /// \param[in] value The value.
878 void set_value(const ValuePtr &value);
879
880 /// \brief Get the value of this ValueNode.
881 ///
882 /// \return The value.
883 const ValuePtr &value() const;
884
885 std::string fullname_with_scope() override;
886
887 /// \brief Set whether this ValueNode has a new value.
888 ///
889 /// \param[in] flag Whether this ValueNode has a new value.
890 void set_has_new_value(bool flag);
891
892 /// \brief Check whether this ValueNode has a new value.
893 ///
894 /// \return Whether this ValueNode has a new value.
895 bool has_new_value() const;
896
897 /// \brief Get the count of graphs using this ValueNode.
898 ///
899 /// \return The count of graphs using this ValueNode.
900 size_t used_graph_count() const;
901
902 /// \brief Set the count of groups using this ValueNode.
903 ///
904 /// \param[in] group The count of groups using this ValueNode.
905 void set_fracz_group(int64_t group);
906
907 /// \brief Get groups attr in FRACTAL_Z format.
908 ///
909 /// \return Groups attr in FRACTAL_Z format.
910 int64_t fracz_group() const;
911
912 /// \brief Set the count of graphs using this ValueNode.
913 ///
914 /// \param[in] used_graph_count The count of graphs using this ValueNode.
915 void set_used_graph_count(size_t used_graph_count);
916
917 std::string ToString() const override;
918 std::string DebugString(int recursive_level = 1) const override;
919 std::string DebugString(bool recursive) const override;
920
921 bool operator==(const AnfNode &other) const override;
922 friend std::ostream &operator<<(std::ostream &os, const ValueNodePtr &node) {
923 MS_EXCEPTION_IF_NULL(node);
924 os << node->ToString();
925 return os;
926 }
927
/// \brief Set the debugging information of this ValueNode.
///
/// \param[in] debug_info The new debugging information.
931 void set_debug_info(const NodeDebugInfoPtr &debug_info) override;
932
933 private:
934 void Init();
935 struct FormatAttr {
936 int64_t fracz_group = 1;
937 int64_t input_size = 0;
938 int64_t hidden_size = 0;
939 };
940 FormatAttr format_attr_;
941 ValuePtr value_;
942 size_t used_graph_count_{0};
943 bool has_new_value_ = false;
944 };
945
/// \brief Maps a C++ scalar type to the shared-pointer type of its Imm value wrapper.
///
/// The primary template intentionally has no `type` member. Default template arguments
/// such as `typename ImmTraits<S>::type` (used by MakeValue and GetValue in this header)
/// therefore fail substitution for any type lacking an IMM_TRAITS specialization.
template <class T>
struct ImmTraits {};

/// \brief Declares the specialization ImmTraits<prototype> whose `type` is `typeimm`.
///
/// The expansion already ends in `;`, so an invocation needs no trailing semicolon.
#define IMM_TRAITS(typeimm, prototype) \
  template <>                          \
  struct ImmTraits<prototype> {        \
    using type = typeimm;              \
  };
954
/// \brief Pass-through overload of MakeValue for arguments that are already a ValuePtr.
///
/// \param[in] value The value to hand back.
/// \return The same shared pointer that was passed in.
inline ValuePtr MakeValue(const ValuePtr &value) {
  return value;
}
956
957 template <typename S, typename U = typename ImmTraits<S>::type::element_type>
MakeValue(S v)958 inline ValuePtr MakeValue(S v) {
959 return std::make_shared<U>(v);
960 }
961
962 template <typename S, typename U = typename ImmTraits<S>::type>
GetValue(const ValuePtr & value)963 static S GetValue(const ValuePtr &value) {
964 MS_EXCEPTION_IF_NULL(value);
965 auto imm = value->cast_ptr<typename U::element_type>();
966 if (imm == nullptr) {
967 MS_LOG(INTERNAL_EXCEPTION) << "Cast failed, original value: " << value->ToString()
968 << ", type: " << value->type_name();
969 }
970 return imm->value();
971 }
972
973 template <typename S,
974 typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
975 S>::type * = nullptr>
GetValue(const ValuePtr & value)976 static S GetValue(const ValuePtr &value) {
977 MS_EXCEPTION_IF_NULL(value);
978 S v = value->cast<S>();
979 if (v == nullptr) {
980 MS_LOG(INTERNAL_EXCEPTION) << "Cast failed, original value: " << value->ToString()
981 << ", type: " << value->type_name();
982 }
983 return v;
984 }
985
// Get the name of the function invoked by `cnode`.
// NOTE(review): whether this is a Primitive name or a FuncGraph name is decided in the implementation — confirm there.
MS_CORE_API std::string GetCNodeFuncName(const CNodePtr &cnode);

// Get the FuncGraphPtr from the first input of a CNode.
MS_CORE_API FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node);

// Check whether `node` is a CNode whose first input is a Primitive.
// `value` defaults to nullptr. Presumably a null `value` accepts any Primitive, while a non-null
// `value` restricts the match to that Primitive — confirm in the implementation.
MS_CORE_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);

// Same check as IsPrimitiveCNode, except that when the first input is a DoSignature Primitive,
// the real Primitive it wraps is obtained first and used for the comparison.
MS_CORE_API bool IsPrimitiveCNodeWithoutDoSignature(const AnfNodePtr &node, const PrimitivePtr &value);

// Get the PrimitivePtr from the first input of a CNode.
MS_CORE_API PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);

// Return the function Primitive if the Primitive is a DoSignaturePrimitive,
// otherwise return the Primitive directly.
// NOTE(review): presumably inspects `node` itself, in contrast to the CNode variant below — confirm.
MS_CORE_API PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node);
// Same as GetPrimitiveWithoutDoSignature, but checks the first input of a CNode.
MS_CORE_API PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node);

// Return the function value if `value` is a DoSignaturePrimitive,
// otherwise return `value` directly.
MS_CORE_API ValuePtr GetValueWithoutDoSignature(const ValuePtr &value);
// AnfNode overload: return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node);
// Same as GetValueWithoutDoSignature, but checks the first input of a CNode.
MS_CORE_API ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node);
1019
/// \brief Check whether the given node is a ValueNode with some Primitive value.
///
/// \param[in] node The input node.
/// \param[in] value The Primitive value to look for.
/// \return Whether the given node is a ValueNode with that Primitive value.
MS_CORE_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value);

/// \brief Check whether the given node is a ValueNode whose Primitive belongs to a primitive set.
///
/// \param[in] node The input node.
/// \param[in] prim_set The set of candidate Primitives.
/// \return Whether the given node is a ValueNode holding a Primitive from `prim_set`.
MS_CORE_API bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set);

/// \brief Check whether the given node is a CNode belonging to a primitive set.
///
/// \param[in] node The input node.
/// \param[in] prim_set The set of candidate Primitives.
/// \return Whether the given node is a CNode belonging to `prim_set`.
MS_CORE_API bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);

// Check whether two primitives are the same.
// NOTE(review): the equality criterion (identity, name, or attributes) lives in the implementation — confirm.
MS_CORE_API bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2);
1039
// Get the number of AbstractMonad entries in `args`.
MS_CORE_API size_t GetAbstractMonadNum(const AbstractBasePtrList &args);

// Check whether the given node has a monad abstract.
MS_CORE_API bool HasAbstractMonad(const AnfNodePtr &node);

// Check whether the given node has a U monad abstract.
MS_CORE_API bool HasAbstractUMonad(const AnfNodePtr &node);

// Check whether the given node has an IO monad abstract.
MS_CORE_API bool HasAbstractIOMonad(const AnfNodePtr &node);

// Get a primitive attribute value as a bool flag.
MS_CORE_API bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr);

// Get the effect info of a primitive from its attributes.
MS_CORE_API EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim);

// Check whether the monad state is equivalent for two connected nodes.
// The check is not strict, but it is faster.
MS_CORE_API bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner);

// Check whether the node is a DeadNode.
MS_CORE_API bool IsDeadNode(const AnfNodePtr &node);

// Check whether the node is a PolyNode.
MS_CORE_API bool IsPolyNode(const AnfNodePtr &node);
1066
1067 // Used to check whether a ValueNode has some kind of value.
1068 template <typename T>
IsValueNode(const AnfNodePtr & node)1069 bool IsValueNode(const AnfNodePtr &node) {
1070 auto value_node = dyn_cast_ptr<ValueNode>(node);
1071 if (value_node == nullptr) {
1072 return false;
1073 }
1074 const auto &value = value_node->value();
1075 return (value != nullptr) && (value->isa<T>());
1076 }
1077
GetValueNode(const AnfNodePtr & node)1078 inline ValuePtr GetValueNode(const AnfNodePtr &node) {
1079 auto value_node = dyn_cast_ptr<ValueNode>(node);
1080 return (value_node == nullptr) ? nullptr : value_node->value();
1081 }
1082
GetValuePtr(const AnfNodePtr & node)1083 inline Value *GetValuePtr(const AnfNodePtr &node) {
1084 auto value_node = dyn_cast_ptr<ValueNode>(node);
1085 return (value_node == nullptr) ? nullptr : value_node->value().get();
1086 }
1087
1088 template <typename S,
1089 typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
1090 S>::type * = nullptr>
GetValueNode(const AnfNodePtr & node)1091 inline S GetValueNode(const AnfNodePtr &node) {
1092 auto value = GetValuePtr(node);
1093 return (value == nullptr) ? nullptr : value->cast<S>();
1094 }
1095
1096 template <typename S, typename std::enable_if<std::is_base_of<Value, S>::value, S>::type * = nullptr>
GetValuePtr(const AnfNodePtr & node)1097 inline S *GetValuePtr(const AnfNodePtr &node) {
1098 auto value_node = dyn_cast_ptr<ValueNode>(node);
1099 if (value_node == nullptr) {
1100 return nullptr;
1101 }
1102 const auto &value = value_node->value();
1103 return (value == nullptr) ? nullptr : value->cast_ptr<S>();
1104 }
1105
// Get a new SeenNum generation.
// NOTE(review): presumably used to mark nodes visited during a graph traversal — confirm in the implementation.
MS_CORE_API SeenNum NewSeenGeneration();

// Helpers for producing id strings.
namespace id_generator {
// Get an id string for `front_string`.
// NOTE(review): presumably `front_string` followed by a counter value — confirm in the implementation.
MS_CORE_API std::string get_id(const std::string &front_string);
// Reset the id generator state.
MS_CORE_API void reset_id();
// Reset the id generator state, keeping an offset (see the implementation for the offset rule).
MS_CORE_API void reset_id_with_offset();
}  // namespace id_generator
// Map from a node to a size_t tag.
using TaggedNodeMap = mindspore::HashMap<AnfNodePtr, size_t>;
// A FuncGraph paired with the tags of its nodes.
using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
// Get the target recorded for a CNode.
MS_CORE_API std::string GetCNodeTarget(const AnfNodePtr &node);
// Get the target of the node's origin.
// NOTE(review): unlike its neighbours this declaration has no MS_CORE_API; confirm it is intentionally not exported.
std::string GetOriginNodeTarget(const AnfNodePtr &node);
// Check whether `nodes` span more than one target.
MS_CORE_API bool ContainMultiTarget(const AnfNodePtrList &nodes);
1118 struct GraphSegment {
GraphSegmentGraphSegment1119 GraphSegment(const AnfNodePtrList &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
AddPreSegmentGraphSegment1120 void AddPreSegment(const std::shared_ptr<GraphSegment> &segment) { (void)pre_segments_.insert(segment); }
1121 AnfNodePtrList nodes_;
1122 std::set<std::shared_ptr<GraphSegment>> pre_segments_;
1123 bool is_cut_{false};
1124 uint32_t graph_id_{0};
1125 };
1126 using GraphSegmentPtr = std::shared_ptr<GraphSegment>;
1127
1128 constexpr auto kElementsUseFlagsKey = "elements_use_flags";
GetSequenceNodeElementsUseFlags(const AnfNodePtr & node)1129 inline std::shared_ptr<std::vector<bool>> GetSequenceNodeElementsUseFlags(const AnfNodePtr &node) {
1130 MS_EXCEPTION_IF_NULL(node);
1131 return node->template user_data<std::vector<bool>>(kElementsUseFlagsKey);
1132 }
1133
SetSequenceNodeElementsUseFlags(const AnfNodePtr & node,const std::shared_ptr<std::vector<bool>> & flags)1134 inline void SetSequenceNodeElementsUseFlags(const AnfNodePtr &node, const std::shared_ptr<std::vector<bool>> &flags) {
1135 MS_EXCEPTION_IF_NULL(node);
1136 node->set_user_data(kElementsUseFlagsKey, flags);
1137 }
1138
// Set the sequence nodes' elements use flag at position `index` to `new_flag`.
// NOTE(review): the sequence nodes are presumably those recorded in `abs` — confirm in the implementation.
MS_CORE_API void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, std::size_t index, bool new_flag);
// Set all of the sequence nodes' elements use flags to `new_flag`.
MS_CORE_API void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag);
// Set all of the sequence nodes' elements use flags to `new_flag`, recursively.
MS_CORE_API void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag);
1145 } // namespace mindspore
1146 #endif // MINDSPORE_CORE_IR_ANF_H_
1147