• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_H_
20 #define MINDSPORE_CORE_IR_FUNC_GRAPH_H_
21 
22 #include <set>
23 #include <map>
24 #include <string>
25 #include <vector>
26 #include <list>
27 #include <unordered_map>
28 #include <memory>
29 #include <functional>
30 #include <utility>
31 
32 #include "utils/hash_map.h"
33 #include "utils/hash_set.h"
34 #include "utils/ordered_set.h"
35 #include "utils/ordered_map.h"
36 #include "mindapi/base/macros.h"
37 #include "base/base_ref.h"
38 #include "base/effect_info.h"
39 #include "ir/anf.h"
40 #include "ir/manager.h"
41 #include "ir/func_graph_transform.h"
42 #include "ir/func_graph_base.h"
43 #include "abstract/abstract_value.h"
44 #include "mindspore/core/symbolic_shape/symbol_engine.h"
45 
46 namespace mindspore {
47 using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
48 using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
49 
50 struct CNodeIndexHasher {
operatorCNodeIndexHasher51   std::size_t operator()(const CNodeIndexPairPtr pair) const {
52     MS_EXCEPTION_IF_NULL(pair);
53     MS_EXCEPTION_IF_NULL(pair->first);
54     return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
55   }
56 };
57 
58 struct CNodeIndexEqual {
operatorCNodeIndexEqual59   bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
60     if (lhs == nullptr || rhs == nullptr) {
61       return false;
62     }
63     if (lhs == rhs) {
64       return true;
65     }
66     if (lhs->first != rhs->first) {
67       return false;
68     }
69     if (lhs->second != rhs->second) {
70       return false;
71     }
72     return true;
73   }
74 };
75 
76 template <typename KeyT, class CounterHash = std::hash<KeyT>, class CounterEqual = std::equal_to<KeyT>>
77 using CounterOrderedMap = OrderedMap<KeyT, int, CounterHash, CounterEqual>;
78 using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
79 using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
80 
81 using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
82 
// Attribute / flag keys stored in FuncGraph::attrs_ (see has_flag / set_flag / set_attr).
// `constexpr` at namespace scope keeps the same internal linkage as the plain `const` form.

// Inlining, reuse and block structure.
constexpr char FUNC_GRAPH_FLAG_IGNORE_VALUE[] = "ignore_value";
constexpr char FUNC_GRAPH_FLAG_VMAP_TRANSFORMED[] = "vmap_transformed";
constexpr char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
constexpr char FUNC_GRAPH_FLAG_PRIMAL_OF_BPROP[] = "primal_of_bprop";
constexpr char FUNC_GRAPH_FLAG_SPARSE_BPROP[] = "sparse_bprop";
constexpr char FUNC_GRAPH_FLAG_NO_INLINE[] = "no_inline";
constexpr char FUNC_GRAPH_FLAG_CELL_REUSE[] = "cell_reuse";
constexpr char FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER[] = "lazy_inline_order";
constexpr char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
constexpr char FUNC_GRAPH_FLAG_CORE[] = "core";
constexpr char FUNC_GRAPH_FLAG_K_GRAPH[] = "k_graph";
constexpr char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
constexpr char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";

// Recomputation markers.
constexpr char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
constexpr char FUNC_GRAPH_RECOMPUTE_K_GRAPH[] = "recompute_k_graph";
constexpr char FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH[] = "recompute_grad_graph";
constexpr char FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH[] = "not_recompute_k_graph";

// Miscellaneous graph flags.
constexpr char FUNC_GRAPH_FLAG_FORCE_INLINE[] = "force_inline";
constexpr char FUNC_GRAPH_FLAG_DUMP[] = "dump";
constexpr char FUNC_GRAPH_FLAG_DYNAMIC_SHAPE[] = "dynamic_shape";
constexpr char FUNC_GRAPH_FLAG_NO_RECURSIVE[] = "no_recursive";
constexpr char FUNC_GRAPH_FLAG_ARGS_NO_EXPAND[] = "args_no_expand";
constexpr char FUNC_GRAPH_FLAG_PROXY_GRAPH[] = "proxy_graph";
constexpr char FUNC_GRAPH_FLAG_NO_CHILD_GRAPH[] = "no_child_graph";
107 
// Additional FuncGraph attribute keys (kName style).
// NOTE(review): the runtime value below is spelled "undeterminate"; it is kept byte-for-byte
// because stored attributes and other code may match on this exact string.
constexpr char kFuncGraphFlagUndetermined[] = "undeterminate";
constexpr char kFuncGraphFlagBackPropEntry[] = "back_prop_entry";
constexpr char kFuncGraphFlagReAutoMonad[] = "re_auto_monad";
constexpr char kFuncGraphFlagRecursive[] = "recursive";
constexpr char kFuncGraphFlagMetaFuncGraphBprop[] = "meta_fg_bprop";
constexpr char kFuncGraphFlagAddedForwardU[] = "added_forward_u";
114 
115 class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
116  public:
117   using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
118 
119   FuncGraph();
120   explicit FuncGraph(GraphDebugInfoPtr &&debug_info);
121   ~FuncGraph();
122   MS_DECLARE_PARENT(FuncGraph, FuncGraphBase);
123 
124   void DoBreakLoop() override;
125 
126   // Get the graph's abstract.
127   abstract::AbstractFunctionPtr abstract();
128   abstract::AbstractBasePtr ToAbstract() override;
129 
130   // get function graph inputs, but parameters
131   const AnfNodePtrList get_inputs() const;
parameters()132   const AnfNodePtrList &parameters() const { return parameters_; }
133   // Append
134   virtual ParameterPtr add_parameter();
135   ParameterPtr add_parameter(NodeDebugInfoPtr &&debug_info);
136   void add_parameter(const ParameterPtr &param);
append_parameter(const ParameterPtr & p)137   void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
138   // Prepend
139   virtual ParameterPtr InsertFrontParameter();
140   void InsertFrontParameter(const ParameterPtr &param);
PrependParameter(const ParameterPtr & p)141   void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
set_parameters(const AnfNodePtrList & params)142   void set_parameters(const AnfNodePtrList &params) { parameters_ = params; }
set_parameters(AnfNodePtrList && params)143   void set_parameters(AnfNodePtrList &&params) { parameters_ = std::move(params); }
144   // Add a FV weight parameter with specific name.
145   ParameterPtr AddFvParameter(const std::string &name, const ValuePtr &default_value);
146 
147   // Create a CNode with given inputs, bound to this graph.
148   virtual CNodePtr NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs);
149   virtual CNodePtr NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs);
150 
151   // @deprecated
152   // To use 'CNodePtr NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs)' instead.
153   virtual CNodePtr NewCNode(AnfNodePtrList &&inputs);
154   // @deprecated
155   // To use 'CNodePtr NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs)' instead.
156   virtual CNodePtr NewCNode(const AnfNodePtrList &inputs);
157 
158   CNodePtr NewCNode(const PrimitivePtr &primitive, const AnfNodePtrList &inputs);
159 
160   // Create a CNode with given weak inputs, bound to this graph and push back to order list.
161   CNodePtr NewCNodeInOrderWeak(AnfNodeWeakPtrList &&weak_inputs);
162   CNodePtr NewCNodeInOrderWeak(const AnfNodeWeakPtrList &weak_inputs);
163 
164   // Create a CNode with given inputs, bound to this graph and push back to order list.
165   CNodePtr NewCNodeInOrder(AnfNodePtrList &&inputs);
166   CNodePtr NewCNodeInOrder(const AnfNodePtrList &inputs = AnfNodePtrList());
167   CNodePtr NewCNodeInOrder(const PrimitivePtr &primitive, const AnfNodePtrList &inputs);
168 
169   // Create a CNode with given inputs, bound to this graph and push back to front of order list.
170   CNodePtr NewCNodeInFront(const AnfNodePtrList &inputs = AnfNodePtrList());
171 
172   // Create a CNode with given inputs, put it to order list before the position node.
173   CNodePtr NewCNodeBefore(const AnfNodePtr &position, const AnfNodePtrList &inputs);
174 
175   // Create a CNode with given inputs, put it to order list after the position node.
176   CNodePtr NewCNodeAfter(const AnfNodePtr &position, const AnfNodePtrList &inputs);
177 
178   // Functions for handling variable argument, keyword-only arguments and variable keyword argument.
179   AnfNodePtr GetDefaultValueByName(const std::string &name);
set_param_default_value(const std::string & name,const AnfNodePtr & node)180   void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
181     parameter_default_value_[name] = node;
182   }
183   void SetDefaultValues(const std::vector<std::string> &name_list, const AnfNodePtrList &value_list);
184   void ClearDefaultValues();
185   size_t GetDefaultValueCount();
parameter_default_value()186   std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
set_has_vararg(bool has_)187   void set_has_vararg(bool has_) { has_vararg_ = has_; }
has_vararg()188   bool has_vararg() const { return has_vararg_; }
189   // Parameters are ordered as: Positional Parameters, Kwonlyargs, *Varargs, **Kwargs, HyperParam;
190   AnfNodePtr GetVariableArgParameter();
191   std::string GetVariableArgName();
set_has_kwarg(bool has_)192   void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
has_kwarg()193   bool has_kwarg() const { return has_kwarg_; }
set_kwonlyargs_count(int count)194   void set_kwonlyargs_count(int count) { kw_only_args_count_ = count; }
kwonlyargs_count()195   int kwonlyargs_count() const { return kw_only_args_count_; }
196   AnfNodePtr GetVariableKwargParameter();
197   std::string GetVariableKwargName();
198   AnfNodePtrList GetKwOnlyArgsParameters();
set_fv_param_count(size_t count)199   void set_fv_param_count(size_t count) { fv_param_count_ = count; }
fv_param_count()200   size_t fv_param_count() const { return fv_param_count_; }
201   int GetPositionalArgsCount() const;
202   AnfNodePtr GetParameterByName(const std::string &name);
203   bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
204   FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list);
set_is_generate(bool generated)205   void set_is_generate(bool generated) { is_generated_ = generated; }
is_generated()206   bool is_generated() const { return is_generated_; }
207 
attrs()208   mindspore::HashMap<std::string, ValuePtr> &attrs() { return attrs_; }
set_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)209   void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
210     for (auto &attr : attrs) {
211       attrs_[attr.first] = attr.second;
212     }
213   }
214   bool has_flag(const std::string &key) const;
set_flag(const std::string & key,bool flag)215   void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
erase_flag(const std::string & key)216   void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
217 
218   bool has_attr(const std::string &key) const;
219   ValuePtr get_attr(const std::string &key) const;
set_attr(const std::string & key,const ValuePtr & value)220   void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
221 
transforms()222   mindspore::HashMap<std::string, FuncGraphTransform> &transforms() { return transforms_; }
set_transforms(const mindspore::HashMap<std::string,FuncGraphTransform> & transforms)223   void set_transforms(const mindspore::HashMap<std::string, FuncGraphTransform> &transforms) {
224     transforms_ = transforms;
225   }
226 
227   // Return the graph's output, or nullptr if not yet deduced.
228   AnfNodePtr output() const;
229   void set_output(const AnfNodePtr &value, bool force_new_ret = false);
230 
get_return()231   CNodePtr get_return() const { return return_.lock(); }
return_node()232   const CNodePtr return_node() const { return return_.lock(); }
set_return(const CNodePtr & cnode)233   void set_return(const CNodePtr &cnode) {
234     return_owner_ = cnode;
235     return_ = CNodeWeakPtr(cnode);
236   }
ResetReturnOwner()237   void ResetReturnOwner() { return_owner_.reset(); }
238 
239   const std::list<AnfNodePtr> &own_nodes() const;
240   void AddOwnNode(const AnfNodePtr &node);
241   void AddOwnNode(const AnfNodePtrList &nodes);
242   void AddOwnNode(const AnfNodeWeakPtrList &weak_nodes);
243   void RemoveOwnNode(const AnfNodePtr &node);
244   void ResetOwnNodes();
245 
manager()246   FuncGraphManagerPtr manager() const { return manager_.lock(); }
set_manager(const FuncGraphManagerPtr & m)247   void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
248 
249   std::string ToString() const override;
250   GraphDebugInfoPtr debug_info();
set_debug_info(const GraphDebugInfoPtr & info)251   void set_debug_info(const GraphDebugInfoPtr &info) {
252     if (info == nullptr) {
253       MS_LOG(INTERNAL_EXCEPTION) << "Graph set null debug info";
254     }
255     this->debug_info_ = info;
256   }
257   // Get all nodes belonging to this func graph.
258   const AnfNodeSet &nodes() const;
259   const AnfNodeSet &switch_nodes() const;
260   void CopyNodes(const FuncGraphPtr &source);
261   void ClearNodes();
262   void AddNode(const AnfNodePtr &node);
263   void DropNode(const AnfNodePtr &node);
264 
265   // Get all value_nodes belonging to this func graph.
266   const AnfNodeCounterMap &value_nodes() const;
267   void CopyValueNodes(const FuncGraphPtr &source);
268   void ClearValueNodes();
269   void AddValueNode(const AnfNodePtr &node, int count = 1);
270   void DropValueNode(const AnfNodePtr &node);
271 
272   // Get all free vars directly used in this func graph.
273   const AnfNodeCounterMap &free_variables() const;
274   void CopyFreeVariables(const FuncGraphPtr &source);
275   void ClearFreeVariables();
276   bool AddFreeVariable(const AnfNodePtr &node, int count = 1);
277   bool DropFreeVariable(const AnfNodePtr &node);
278 
279   // Get all vars required by this func graph.
280   const BaseRefCounterMap &free_variables_total();
281 
282   // Return the set of graphs free_variables_total belong to.
283   AnfNodePtrList free_variables_nodes();
284 
285   // Get all vars that are func graphs
286   std::vector<FuncGraphPtr> free_variables_func_graphs();
287 
288   // Get all value nodes of func graph directly used by this func graph.
289   const FuncGraphCounterMap &func_graphs_used() const;
290   void CopyFuncGraphsUsed(const FuncGraphPtr &source);
291   void ClearFuncGraphsUsed();
292   bool AddFuncGraphUsed(const FuncGraphPtr &fg, int count = 1);
293   bool DropFuncGraphUsed(const FuncGraphPtr &fg);
294 
295   // Get all value nodes in the inputs of MetaFgPrim directly used by this func graph.
296   const mindspore::HashMap<AnfNodePtr, int> &meta_fg_prim_value_nodes() const;
297   void CopyMetaFgPrimValueNodes(const FuncGraphPtr &source);
298   void ClearMetaFgPrimValueNodes();
299   void AddMetaFgPrimValueNode(const AnfNodePtr &value_node, int count = 1);
300   void DropMetaFgPrimValueNode(const AnfNodePtr &value_node);
301 
302   // Get all func graphs nested used by this func graph.
303   const FuncGraphSet &func_graphs_used_total();
304 
305   // Get all user value nodes of this func graph, by CNode and its input's index.
306   const CNodeIndexCounterMap &func_graph_cnodes_index() const;
307   void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
308   void ClearFuncGraphCNodesIndex();
309   void AddFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair, int count = 1);
310   void DropFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair);
311 
312   // Return the parent of this graph.
313   FuncGraphPtr parent();
314 
315   // Return the children of this graph.
316   const FuncGraphSet &children();
317 
318   // Return the scope of this graph, scope have graph self but children not have.
319   const FuncGraphSet &scope();
320 
321   // Return whether this graph is recursive.
322   bool recursive();
323 
324   // Return graphs which forms a recursive loop.
325   std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs();
326 
hash()327   std::size_t hash() const override { return PointerHash<FuncGraph>{}(this); }
328 
329   bool operator==(const Value &other) const override {
330     if (other.isa<FuncGraph>()) {
331       return &other == this;
332     } else {
333       return false;
334     }
335   }
336   void GenerateVarParams(const FuncGraphPtr &specialized_graph, int variable_args_count, int pos_args_input_count,
337                          AnfNodePtrList *specialized_parameter_list,
338                          mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
339 
340   void GenerateKwParams(const FuncGraphPtr &specialized_graph,
341                         const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list, int pos_args_input_count,
342                         AnfNodePtrList *specialized_parameter_list,
343                         mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
344 
345   void GenerateDefaultValue(const FuncGraphPtr &specialized_graph, const AnfNodePtrList &specialized_parameter_list,
346                             mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
347 
parameter_obj_nodes()348   const AnfNodePtrList &parameter_obj_nodes() const { return parameter_obj_nodes_; }
add_parameter_obj_node(const AnfNodePtr & p)349   void add_parameter_obj_node(const AnfNodePtr &p) { parameter_obj_nodes_.push_back(p); }
350 
351   mindspore::HashMap<std::string, ValuePtr> attrs_;
352   mindspore::HashMap<std::string, FuncGraphTransform> transforms_;
353   // Parameter default value.
354   std::map<std::string, AnfNodePtr> parameter_default_value_;
355 
356   SeenNum seen_{0};
357   SeenNum extra_seen_{0};
358 
359   std::list<CNodePtr> GetOrderedCnodes();
360   void EraseUnusedNodeInOrder(const AnfNodePtr &node);
361   void EraseUnusedNodeInOrder();
362   void DumpCNodeList();
order_list()363   const std::list<CNodeWeakPtr> &order_list() const { return order_; }
364 
set_order_list(std::list<CNodeWeakPtr> && order_list)365   void set_order_list(std::list<CNodeWeakPtr> &&order_list) { order_ = std::move(order_list); }
366 
367   // Add a CNode at the end of order list.
AppendOrderList(const CNodePtr & cnode)368   void AppendOrderList(const CNodePtr &cnode) { (void)order_.emplace_back(CNodeWeakPtr(cnode)); }
369 
370   // Prepend CNode at the front of order list.
PrependOrderList(const CNodePtr & cnode)371   void PrependOrderList(const CNodePtr &cnode) { (void)order_.emplace_front(CNodeWeakPtr(cnode)); }
372 
373   // Maintain CNode order list when a CNode is replaced by a new one.
374   void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
375 
376   // Clear CNode order list.
ClearOrderList()377   void ClearOrderList() { order_.clear(); }
378 
stub()379   bool stub() const { return stub_; }
set_stub(bool stub)380   void set_stub(bool stub) { stub_ = stub; }
381 
indirect()382   std::shared_ptr<bool> indirect() {
383     // Lazy initialization.
384     if (!indirect_) {
385       indirect_ = std::make_shared<bool>(false);
386     }
387     return indirect_;
388   }
set_indirect(std::shared_ptr<bool> indirect)389   void set_indirect(std::shared_ptr<bool> indirect) { indirect_ = indirect; }
390 
391   void SetMultiTarget() const;
exist_multi_target()392   bool exist_multi_target() const { return exist_multi_target_; }
set_exist_multi_target(bool exist_multi_target)393   void set_exist_multi_target(bool exist_multi_target) { exist_multi_target_ = exist_multi_target; }
stage()394   int64_t stage() const { return stage_; }
set_stage(int64_t stage)395   void set_stage(int64_t stage) { stage_ = stage; }
segment()396   int64_t segment() const { return segment_; }
set_segment(int64_t segment)397   void set_segment(int64_t segment) { segment_ = segment; }
dynamic_shape()398   bool dynamic_shape() { return dynamic_shape_; }
set_dynamic_shape(bool dynamic_shape)399   void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
400 
dropped()401   bool dropped() const { return dropped_; }
set_dropped(bool dropped)402   void set_dropped(bool dropped) { dropped_ = dropped; }
403 
bprop_hash()404   std::string bprop_hash() const { return bprop_hash_; }
set_bprop_hash(const std::string & bprop_hash)405   void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }
406 
bprop_filepath()407   std::string bprop_filepath() const { return bprop_filepath_; }
set_bprop_filepath(const std::string & bprop_filepath)408   void set_bprop_filepath(const std::string &bprop_filepath) { bprop_filepath_ = bprop_filepath; }
409 
modify_output()410   bool modify_output() const { return modify_output_; }
set_modify_output(bool modify_output)411   void set_modify_output(bool modify_output) { modify_output_ = modify_output; }
used_forward_nodes()412   const mindspore::OrderedSet<AnfNodePtr> &used_forward_nodes() const { return used_forward_nodes_; }
413   void set_used_forward_nodes(const AnfNodePtrList &used_forward_nodes);
ClearUsedForwardNodes()414   void ClearUsedForwardNodes() { used_forward_nodes_.clear(); }
415 
is_tensor_condition_branch()416   bool is_tensor_condition_branch() const { return is_tensor_condition_branch_; }
set_is_tensor_condition_branch(bool is_tensor_condition_branch)417   void set_is_tensor_condition_branch(bool is_tensor_condition_branch) {
418     is_tensor_condition_branch_ = is_tensor_condition_branch;
419   }
420 
421   /// \brief Topological sort a graph from the given end node.
422   ///
423   /// \param[in] node The end node of the graph to be sorted.
424   ///
425   /// \return The sorted nodes.
426   static AnfNodePtrList TopoSort(const AnfNodePtr &node);
427 
set_python_obj(const ValuePtr & python_obj)428   void set_python_obj(const ValuePtr &python_obj) { python_obj_ = python_obj; }
python_obj()429   ValuePtr python_obj() const { return python_obj_; }
430 
phase()431   const std::string &phase() const { return phase_; }
432 
set_symbol_engine(const SymbolEnginePtr & se)433   void set_symbol_engine(const SymbolEnginePtr &se) { symbol_engine_ = se; }
symbol_engine()434   const SymbolEnginePtr &symbol_engine() const { return symbol_engine_; }
435 
436   // Only used for func_graph manager to control resource free.
attached_mng_cnt()437   int attached_mng_cnt() const { return attached_mng_cnt_; }
438 
439   // Reserve the func graph, not to release in manager.
set_reserved(bool reserved)440   void set_reserved(bool reserved) { reserved_ = reserved; }
reserved()441   bool reserved() const { return reserved_; }
442 
443  private:
444   // Only used for func_graph manager to control resource free.
IncAttachedMngCnt()445   void IncAttachedMngCnt() { attached_mng_cnt_++; }
DecAttachedMngCnt()446   void DecAttachedMngCnt() { attached_mng_cnt_--; }
447   // Clear all info from manager.
448   void ClearAllResource();
449 
450   // Graph is manipulated by manager and others.
451   friend FuncGraphManager;
452 
453   // All nodes of the function.
454   AnfNodeSet nodes_;
455 
456   // All switch nodes of the function.
457   AnfNodeSet switch_nodes_;
458 
459   // All value nodes of the function.
460   AnfNodeCounterMap value_nodes_;
461 
462   // All func graph value nodes of the function.
463   FuncGraphCounterMap func_graphs_used_;
464 
465   // All free variables of the function.
466   AnfNodeCounterMap free_variables_;
467 
468   // All value nodes calling MetaFgPrim in the function.
469   mindspore::HashMap<AnfNodePtr, int> meta_fg_prim_value_nodes_;
470 
471   // All user value nodes of this func graph, recording by CNode and its input's index.
472   CNodeIndexCounterMap func_graph_cnodes_index_;
473 
474   // Parameters of this function.
475   AnfNodePtrList parameters_;
476   AnfNodePtrList parameter_obj_nodes_;
477 
478   // Whether there is a *args and **kwargs, and count kw_only_args'number.
479   bool has_vararg_;
480   bool has_kwarg_;
481   bool exist_multi_target_;
482   int kw_only_args_count_;
483   // Hyper param is used as free variable and placed on the top graph.
484   // and positioned in the end of the param list, so we record the number to trace the position.
485   size_t fv_param_count_;
486   // Argument input list for the graph used to generate this graph.
487   bool is_generated_;
488   // CNode that calls 'return' primitive.
489   // We use shared pointer to manage it.
490   CNodeWeakPtr return_;
491   // Before release all func graphs in Manager, reset the owner firstly.
492   CNodePtr return_owner_;
493 
494   // Back-ref to its manager.
495   // Hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph.
496   // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles.
497   // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs.
498   // In some ut test cases, they may use local FuncGraphManager in function which
499   // generating the func graph, when go outside of that function, func graph will have no
500   // FuncGraphManager. In that special case, Manage() should be called to make the func graph
501   // managed.
502   std::weak_ptr<FuncGraphManager> manager_;
503   int attached_mng_cnt_{0};
504 
505   GraphDebugInfoPtr debug_info_;
506   void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, const AnfNodePtrList &kwarg_keys_tuple_nodes,
507                              const AnfNodePtrList &kwarg_values_tuple_nodes,
508                              mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
509 
510   // CNode order which relates to origin code order.
511   std::list<CNodeWeakPtr> order_;
512   bool stub_;
513 
514   // The graph is used as some input of Switch, SwitchLayer, or Partial.
515   std::shared_ptr<bool> indirect_;
516 
517   int64_t stage_;
518   int64_t segment_;
519   bool dynamic_shape_ = false;
520   std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
521                      abstract::AbstractBasePtrListEqual>
522     func_graph_cache_;
523 
524   // If the graph was changed, it should be dropped in cache data_converter::object_map_
525   // which used by ConvertToFuncGraph.
526   bool dropped_{false};
527   // If the graph is a bprop graph, it should has a hash of the bprop function.
528   std::string bprop_hash_;
529   // If the graph is a bprop graph, it should has a filepath of the bprop function.
530   std::string bprop_filepath_;
531 
532   // If the graph is decorated with @jit and runs grad process in pynative mode,
533   // forward nodes used in grad graph will be added to output for holding output values.
534   bool modify_output_{false};
535   mindspore::OrderedSet<AnfNodePtr> used_forward_nodes_;
536   // If the func_graph is input of switch node, and the condition of switch is AbstractTensor, need set true.
537   bool is_tensor_condition_branch_{false};
538   // Corresponding python obj.
539   ValuePtr python_obj_{nullptr};
540   std::string phase_;
541   // Own all nodes in the func graph.
542   std::list<AnfNodePtr> own_nodes_;
543   // the manager of symbolic shape's symbols and operations.
544   SymbolEnginePtr symbol_engine_;
545   // Reserve the func graph, not to release in manager.
546   bool reserved_{false};
547 };
548 
NewCNode(const AnfNodePtrList & inputs,const FuncGraphPtr & fg)549 inline CNodePtr NewCNode(const AnfNodePtrList &inputs, const FuncGraphPtr &fg) {
550   MS_EXCEPTION_IF_NULL(fg);
551   return fg->NewCNode(inputs);
552 }
553 
NewCNode(AnfNodePtrList && inputs,const FuncGraphPtr & fg)554 inline CNodePtr NewCNode(AnfNodePtrList &&inputs, const FuncGraphPtr &fg) {
555   MS_EXCEPTION_IF_NULL(fg);
556   return fg->NewCNode(std::move(inputs));
557 }
558 
559 MS_CORE_API SeenNum NewFgSeenGeneration();
560 
561 // Find the root cnodes of a segment of cnodes.
562 std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
563 // Find the leaf cnodes of a segment of cnodes.
564 std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment);
565 }  // namespace mindspore
566 
567 #endif  // MINDSPORE_CORE_IR_FUNC_GRAPH_H_
568