• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_H_
20 #define MINDSPORE_CORE_IR_FUNC_GRAPH_H_
21 
22 #include <set>
23 #include <map>
24 #include <string>
25 #include <vector>
26 #include <list>
27 #include <memory>
28 #include <unordered_map>
29 #include <unordered_set>
30 #include <functional>
31 #include <utility>
32 
33 #include "ir/anf.h"
34 #include "ir/manager.h"
35 #include "utils/ordered_set.h"
36 #include "utils/ordered_map.h"
37 #include "base/base_ref.h"
38 #include "base/effect_info.h"
39 #include "ir/func_graph_cloner.h"
40 #include "abstract/abstract_value.h"
41 #include "api/ir/func_graph.h"
42 
43 namespace mindspore {
44 using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
45 using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
46 
47 struct CNodeIndexHasher {
operatorCNodeIndexHasher48   std::size_t operator()(const CNodeIndexPairPtr pair) const {
49     MS_EXCEPTION_IF_NULL(pair);
50     MS_EXCEPTION_IF_NULL(pair->first);
51     return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
52   }
53 };
54 
55 struct CNodeIndexEqual {
operatorCNodeIndexEqual56   bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
57     if (lhs == nullptr || rhs == nullptr) {
58       return false;
59     }
60     if (lhs == rhs) {
61       return true;
62     }
63     if (lhs->first != rhs->first) {
64       return false;
65     }
66     if (lhs->second != rhs->second) {
67       return false;
68     }
69     return true;
70   }
71 };
72 
73 template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
74 using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
75 using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
76 using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
77 
78 using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
79 
// Names of func-graph flags and attributes.
// Declared `inline constexpr` (C++17, which this header already relies on for its `inline static`
// member) so all translation units share one definition, instead of each receiving its own
// internal-linkage copy of every array as plain namespace-scope `const char[]` in a header does.
// The string values are unchanged; other components may match on them.
inline constexpr char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
inline constexpr char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
inline constexpr char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
inline constexpr char FUNC_GRAPH_FLAG_CORE[] = "core";
inline constexpr char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
inline constexpr char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
inline constexpr char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
inline constexpr char FUNC_GRAPH_FLAG_FORCE_INLINE[] = "force_inline";

// NOTE(review): the value "Undeterminate" is kept verbatim because it is a runtime key.
inline constexpr char kFuncGraphFlagUndetermined[] = "Undeterminate";
inline constexpr char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
inline constexpr char kFuncGraphFlagReAutoMonad[] = "ReAutoMonad";
inline constexpr char kFuncGraphFlagRecursive[] = "Recursive";
93 
// Forward declarations of the abstract-value types named by FuncGraph's interface, plus their
// shared_ptr aliases.
namespace abstract {
class AbstractFunction;
class AbstractKeywordArg;

using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>;
using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
}  // namespace abstract
100 
101 // ANF transform class.
102 // Either a primitive or a func_graph.
103 class FuncGraphTransform {
104  public:
105   enum Type { kGtPrimitive, kGtFuncGraph };
106 
107   explicit FuncGraphTransform(const PrimitivePtr prim, const FuncGraphPtr func_graph = nullptr)
prim_(prim)108       : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
109 
110   explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_)
prim_(prim)111       : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
112 
FuncGraphTransform(const FuncGraphTransform & t)113   FuncGraphTransform(const FuncGraphTransform &t) : prim_(t.prim_), func_graph_(t.func_graph_) {}
114 
115   ~FuncGraphTransform() = default;
116 
type()117   Type type() const {
118     if (IsFuncGraph()) {
119       return kGtFuncGraph;
120     } else {
121       return kGtPrimitive;
122     }
123   }
124 
IsPrimitive()125   bool IsPrimitive() const { return (func_graph_.lock() == nullptr); }
IsFuncGraph()126   bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); }
func_graph()127   FuncGraphPtr func_graph() const { return func_graph_.lock(); }
primitive()128   PrimitivePtr primitive() const { return prim_; }
129 
130   FuncGraphTransform &operator=(const FuncGraphTransform &t) {
131     if (this != &t) {
132       prim_ = t.prim_;
133       func_graph_ = t.func_graph_;
134     }
135     return *this;
136   }
137 
138  private:
139   PrimitivePtr prim_;
140   // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here.
141   // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in
142   // FPropRemapper::FinalizeGraph().
143   FuncGraphWeakPtr func_graph_;
144   static const PrimitivePtr func_graph_prim_;
145 };
146 
// Common base class for function-graph values: a Value subtype from which FuncGraph derives.
147 class FuncGraphBase : public Value {
148  public:
149   FuncGraphBase() = default;
150 
151   ~FuncGraphBase() override = default;
  // Declares Value as this class's parent in MindSpore's runtime type hierarchy; presumably what
  // isa<>/cast<> checks rely on — confirm in the MS_DECLARE_PARENT definition.
152   MS_DECLARE_PARENT(FuncGraphBase, Value);
153 };
154 
155 class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
156  public:
157   FuncGraph();
158   using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
159 
160   ~FuncGraph() override = default;
161   MS_DECLARE_PARENT(FuncGraph, FuncGraphBase);
162 
163   // Get the graph's abstract.
164   abstract::AbstractFunctionPtr abstract();
165   abstract::AbstractBasePtr ToAbstract() override;
166 
167   // get function graph inputs, but parameters
168   const std::vector<AnfNodePtr> get_inputs() const final;
169   // Return the graph's output, or nullptr if not yet deduced.
170   AnfNodePtr output() const;
171   void set_output(const AnfNodePtr &value, bool force_new_ret = false);
172 
parameters()173   const std::vector<AnfNodePtr> &parameters() const final { return parameters_; }
174   // Append
175   ParameterPtr add_parameter() override;
176   void add_parameter(const ParameterPtr &p) final;
append_parameter(const ParameterPtr & p)177   void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
178   // Prepend
179   virtual ParameterPtr InsertFrontParameter();
180   void InsertFrontParameter(const ParameterPtr &p);
PrependParameter(const ParameterPtr & p)181   void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
set_parameters(const std::vector<AnfNodePtr> & params)182   void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
183   // Add a weight parameter with specific name.
184   ParameterPtr AddWeightParameter(const std::string &name);
185 
186   // Create a cnode with given inputs, bound to this graph.
187   CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()) override;
188   CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) final;
189 
190   // Create a cnode with given inputs, bound to this graph and push back to order list.
191   CNodePtr NewCNodeInOrder(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
192   CNodePtr NewCNodeInOrder(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
193 
194   // Create a cnode with given inputs, bound to this graph and push back to front of order list.
195   CNodePtr NewCNodeInFront(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
196 
197   // Create a cnode with given inputs, put it to order list before the position node.
198   CNodePtr NewCNodeBefore(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs);
199 
200   // Create a cnode with given inputs, put it to order list after the position node.
201   CNodePtr NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs);
202 
203   virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor);
204   // Functions for handling variable argument, keyword-only arguments and variable keyword argument.
205   AnfNodePtr GetDefaultValueByName(const std::string &name);
set_param_default_value(const std::string & name,const AnfNodePtr & node)206   void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
207     parameter_default_value_[name] = node;
208   }
209   void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
210   void ClearDefaultValues();
211   size_t GetDefaultValueCount();
parameter_default_value()212   std::map<std::string, AnfNodePtr> &parameter_default_value() { return parameter_default_value_; }
set_has_vararg(bool has_)213   void set_has_vararg(bool has_) { has_vararg_ = has_; }
has_vararg()214   bool has_vararg() const { return has_vararg_; }
215   AnfNodePtr GetVariableArgParameter();
216   std::string GetVariableArgName();
set_has_kwarg(bool has_)217   void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
has_kwarg()218   bool has_kwarg() const { return has_kwarg_; }
set_kwonlyargs_count(int count)219   void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; }
kwonlyargs_count()220   int kwonlyargs_count() const { return kwonlyargs_count_; }
221   AnfNodePtr GetVariableKwargParameter();
222   std::string GetVariableKwargName();
set_hyper_param_count(size_t count)223   void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
hyper_param_count()224   size_t hyper_param_count() const { return hyper_param_count_; }
225   int GetPositionalArgsCount() const;
226   AnfNodePtr GetParameterByName(const std::string &name);
227   bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
228   FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
set_is_generate(bool generated)229   void set_is_generate(bool generated) { is_generated_ = generated; }
is_generated()230   bool is_generated() const { return is_generated_; }
set_is_bprop(bool is_brop)231   void set_is_bprop(bool is_brop) { is_bprop_ = is_brop; }
is_bprop()232   bool is_bprop() const { return is_bprop_; }
233 
attrs()234   std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; }
set_attrs(const std::unordered_map<std::string,ValuePtr> & attrs)235   void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
236     for (auto &attr : attrs) {
237       attrs_[attr.first] = attr.second;
238     }
239   }
240   bool has_flag(const std::string &key);
set_flag(const std::string & key,bool flag)241   void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
erase_flag(const std::string & key)242   void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
243 
244   bool has_attr(const std::string &key) const final;
245   ValuePtr get_attr(const std::string &key) const final;
set_attr(const std::string & key,const ValuePtr & value)246   void set_attr(const std::string &key, const ValuePtr &value) final { attrs_[key] = value; }
247 
transforms()248   std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
set_transforms(const std::unordered_map<std::string,FuncGraphTransform> & transforms)249   void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
250     transforms_ = transforms;
251   }
252 
get_return()253   CNodePtr get_return() const final { return return_; }
set_return(const CNodePtr & cnode)254   void set_return(const CNodePtr &cnode) final { return_ = cnode; }
255 
manager()256   FuncGraphManagerPtr manager() const { return manager_.lock(); }
set_manager(const FuncGraphManagerPtr & m)257   void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
258 
get_manager()259   api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
260 
261   std::string ToString() const override;
262   GraphDebugInfoPtr debug_info();
set_debug_info(const GraphDebugInfoPtr & info)263   void set_debug_info(const GraphDebugInfoPtr &info) {
264     if (info == nullptr) {
265       MS_LOG(EXCEPTION) << "Graph set null debug info";
266     }
267     this->debug_info_ = info;
268   }
269   // Get all nodes belonging to this func graph.
270   const AnfNodeSet &nodes() const final;
271   void CopyNodes(const FuncGraphPtr &source);
272   void ClearNodes();
273   void AddNode(const AnfNodePtr &node);
274   void DropNode(const AnfNodePtr &node);
275 
276   // Get all value_nodes belonging to this func graph.
277   const AnfNodeCounterMap &value_nodes() const;
278   void CopyValueNodes(const FuncGraphPtr &source);
279   void ClearValueNodes();
280   void AddValueNode(const AnfNodePtr &node, int count = 1);
281   void DropValueNode(const AnfNodePtr &node);
282 
283   // Get all free vars directly used in this func graph.
284   const AnfNodeCounterMap &free_variables() const;
285   void CopyFreeVariables(const FuncGraphPtr &source);
286   void ClearFreeVariables();
287   bool AddFreeVariable(const AnfNodePtr &node, int count = 1);
288   bool DropFreeVariable(const AnfNodePtr &node);
289 
290   // Get all vars required by this func graph.
291   const BaseRefCounterMap &free_variables_total();
292 
293   // Return the set of graphs free_variables_total belong to.
294   std::vector<AnfNodePtr> free_variables_nodes();
295 
296   // Get all vars that are func graphs
297   std::vector<FuncGraphPtr> free_variables_func_graphs();
298 
299   // Get all value nodes of func graph directly used by this func graph.
300   const FuncGraphCounterMap &func_graphs_used() const;
301   void CopyFuncGraphsUsed(const FuncGraphPtr &source);
302   void ClearFuncGraphsUsed();
303   bool AddFuncGraphUsed(const FuncGraphPtr &fg, int count = 1);
304   bool DropFuncGraphUsed(const FuncGraphPtr &fg);
305 
306   // Get all value nodes in the inputs of J directly used by this func graph.
307   const std::unordered_map<AnfNodePtr, int> &j_value_nodes() const;
308   void CopyJValueNodes(const FuncGraphPtr &source);
309   void ClearJValueNodes();
310   void AddJValueNode(const AnfNodePtr &value_node, int count = 1);
311   void DropJValueNode(const AnfNodePtr &value_node);
312 
313   // Get all func graphs nested used by this func graph.
314   const FuncGraphSet &func_graphs_used_total();
315 
316   // Get all user value nodes of this func graph, by CNode and its input's index.
317   const CNodeIndexCounterMap &func_graph_cnodes_index() const;
318   void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
319   void ClearFuncGraphCNodesIndex();
320   void AddFuncGraphCNodeIndex(const CNodeIndexPairPtr &node, int count = 1);
321   void DropFuncGraphCNodeIndex(const CNodeIndexPairPtr &node);
322 
323   // Return the parent of this graph.
324   FuncGraphPtr parent();
325 
326   // Return the children of this graph.
327   const FuncGraphSet &children();
328 
329   // Return the scope of this graph, scope have graph self but children not have.
330   const FuncGraphSet &scope();
331 
332   // Return whether this graph is recursive.
333   bool recursive();
334 
335   // Return graphs which forms a recursive loop.
336   std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs();
337 
hash()338   std::size_t hash() const override { return std::hash<const FuncGraph *>{}(this); }
339 
340   void DumpFuncGraph(const std::string &path = "./func_graph.dot");
341 
342   bool operator==(const Value &other) const override {
343     if (other.isa<FuncGraph>()) {
344       return &other == this;
345     } else {
346       return false;
347     }
348   }
349   void GenerateVarParams(const FuncGraphPtr &specialized_graph, int variable_args_count, int pos_args_input_count,
350                          std::vector<AnfNodePtr> *specialized_parameter_list,
351                          std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
352 
353   void GenerateKwParams(const FuncGraphPtr &specialized_graph,
354                         const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
355                         std::vector<AnfNodePtr> *specialized_parameter_list,
356                         std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
357 
358   void GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
359                             const std::vector<AnfNodePtr> &specialized_parameter_list,
360                             std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
361 
paramter_obj_nodes()362   const std::vector<AnfNodePtr> &paramter_obj_nodes() const { return paramter_obj_nodes_; }
add_parameter_obj_node(const AnfNodePtr & p)363   void add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
364 
365   std::unordered_map<std::string, ValuePtr> attrs_;
366   std::unordered_map<std::string, FuncGraphTransform> transforms_;
367   // Parameter default value.
368   std::map<std::string, AnfNodePtr> parameter_default_value_;
369   size_t seen_;
370 
371   std::list<CNodePtr> GetOrderedCnodes();
372   void EraseUnusedNodeInOrder(const AnfNodePtr &n);
373   void EraseUnusedNodeInOrder();
374   void DumpCNodeList();
order_list()375   const OrderedSet<CNodePtr> &order_list() const { return order_; }
376 
set_order_list(OrderedSet<CNodePtr> && order_list)377   void set_order_list(OrderedSet<CNodePtr> &&order_list) { order_ = std::move(order_list); }
378 
379   // Add a cnode at the end of order list.
AppendOrderList(const CNodePtr & cnode)380   void AppendOrderList(const CNodePtr &cnode) { order_.push_back(cnode); }
381 
382   // Prepend cnode at the front of order list.
PrependOrderList(const CNodePtr & cnode)383   void PrependOrderList(const CNodePtr &cnode) { order_.push_front(cnode); }
384 
385   // Maintain cnode order list when a cnode is replaced by a new one.
386   void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
387 
388   // Clear cnode order list.
ClearOrderList()389   void ClearOrderList() { order_.clear(); }
390 
stub()391   bool stub() const { return stub_; }
set_stub(bool stub)392   void set_stub(bool stub) { stub_ = stub; }
set_drawer(Drawer drawer)393   static void set_drawer(Drawer drawer) { drawer_ = drawer; }
switch_input()394   std::shared_ptr<bool> switch_input() const { return switch_input_; }
set_switch_input(std::shared_ptr<bool> switch_input)395   void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; }
switch_layer_input()396   std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
set_switch_layer_input(std::shared_ptr<bool> switch_layer_input)397   void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
398   bool ContainMultiTarget() const;
stage()399   int64_t stage() { return stage_; }
set_stage(int64_t stage)400   void set_stage(int64_t stage) { stage_ = stage; }
401 
dropped()402   bool dropped() const { return dropped_; }
set_dropped(bool dropped)403   void set_dropped(bool dropped) { dropped_ = dropped; }
404 
bprop_hash()405   std::string bprop_hash() const { return bprop_hash_; }
set_bprop_hash(const std::string & bprop_hash)406   void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }
407 
modify_output()408   bool modify_output() const { return modify_output_; }
set_modify_output(bool modify_output)409   void set_modify_output(bool modify_output) { modify_output_ = modify_output; }
used_forward_nodes()410   const std::unordered_set<AnfNodePtr> &used_forward_nodes() const { return used_forward_nodes_; }
411   void set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes);
ClearUsedForwardNodes()412   void ClearUsedForwardNodes() { used_forward_nodes_.clear(); }
413 
414  private:
415   // Only used for func_graph manager to control resource free.
attached_mng_cnt()416   int attached_mng_cnt() const { return attached_mng_cnt_; }
IncAttachedMngCnt()417   void IncAttachedMngCnt() { attached_mng_cnt_++; }
DecAttachedMngCnt()418   void DecAttachedMngCnt() { attached_mng_cnt_--; }
419   // Clear all info from manager.
420   void ClearAllManagerInfo();
421 
422   // Graph is manipulated by manager and others.
423   friend FuncGraphManager;
424 
425   // All nodes of the function.
426   AnfNodeSet nodes_;
427 
428   // All value nodes of the function.
429   AnfNodeCounterMap value_nodes_;
430 
431   // All func graph value nodes of the function.
432   FuncGraphCounterMap func_graphs_used_;
433 
434   // All free variables of the function.
435   AnfNodeCounterMap free_variables_;
436 
437   // All value nodes calling J in the function.
438   std::unordered_map<AnfNodePtr, int> j_value_nodes_;
439 
440   // All user value nodes of this func graph, recording by CNode and its input's index.
441   CNodeIndexCounterMap func_graph_cnodes_index_;
442 
443   // Parameters of this function.
444   std::vector<AnfNodePtr> parameters_;
445   std::vector<AnfNodePtr> paramter_obj_nodes_;
446 
447   // Whether there is a *args and **kwargs, and count kwonlyargs'number.
448   bool has_vararg_;
449   bool has_kwarg_;
450   int kwonlyargs_count_;
451   // Hyper param is placed on the top graph,
452   // and positioned in the end of the param list, so we record the number to trace the position.
453   size_t hyper_param_count_;
454   // Argument input list for the graph used to generate this graph.
455   bool is_generated_;
456 
457   bool is_bprop_;
458 
459   // CNode that calls 'return' primitive.
460   // We use shared pointer to manage it.
461   CNodePtr return_;
462 
463   // Back-ref to its manager.
464   // Hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph.
465   // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles.
466   // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs.
467   // In some ut test cases, they may use local FuncGraphManager in function which
468   // generating the func graph, when go outside of that function, func graph will have no
469   // FuncGraphManager. In that special case, Manage() should be called to make the func graph
470   // managed.
471   std::weak_ptr<FuncGraphManager> manager_;
472   int attached_mng_cnt_ = 0;
473 
474   GraphDebugInfoPtr debug_info_;
475   void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
476                              const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
477                              const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes,
478                              std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
479 
480   // CNode order which relates to origin code order.
481   OrderedSet<CNodePtr> order_;
482   bool stub_;
483   inline static Drawer drawer_ = nullptr;
484   // Design switch_input and switch_layer_input as a ptr to
485   // share between derived backpropagator and cloned graphs.
486   std::shared_ptr<bool> switch_input_;
487   std::shared_ptr<bool> switch_layer_input_;
488   int64_t stage_;
489   std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
490                      abstract::AbstractBasePtrListEqual>
491     func_graph_cache_;
492 
493   // If the graph was changed, it should be dropped in cache data_converter::object_map_
494   // which used by ConvertToFuncGraph.
495   bool dropped_ = false;
496   // If the graph is a bprop graph, it should has a hash of the bprop directory.
497   std::string bprop_hash_;
498 
499   // If the graph is decorated by @ms_function and runs grad process in pynative mode,
500   // forward nodes used in grad graph will be added to output for holding output values.
501   bool modify_output_ = false;
502   std::unordered_set<AnfNodePtr> used_forward_nodes_;
503 };
504 
NewCNode(const std::vector<AnfNodePtr> & inputs,const FuncGraphPtr & fg)505 inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
506   MS_EXCEPTION_IF_NULL(fg);
507   return fg->NewCNode(inputs);
508 }
509 
510 size_t NewFgSeenGeneration();
511 
512 // Find the root cnodes of a segment of cnodes.
513 std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
514 // Find the leaf cnodes of a segment of cnodes.
515 std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment);
516 }  // namespace mindspore
517 
518 #endif  // MINDSPORE_CORE_IR_FUNC_GRAPH_H_
519