1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_H_
20 #define MINDSPORE_CORE_IR_FUNC_GRAPH_H_
21
22 #include <set>
23 #include <map>
24 #include <string>
25 #include <vector>
26 #include <list>
27 #include <memory>
28 #include <unordered_map>
29 #include <unordered_set>
30 #include <functional>
31 #include <utility>
32
33 #include "ir/anf.h"
34 #include "ir/manager.h"
35 #include "utils/ordered_set.h"
36 #include "utils/ordered_map.h"
37 #include "base/base_ref.h"
38 #include "base/effect_info.h"
39 #include "ir/func_graph_cloner.h"
40 #include "abstract/abstract_value.h"
41 #include "api/ir/func_graph.h"
42
43 namespace mindspore {
44 using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
45 using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
46
47 struct CNodeIndexHasher {
operatorCNodeIndexHasher48 std::size_t operator()(const CNodeIndexPairPtr pair) const {
49 MS_EXCEPTION_IF_NULL(pair);
50 MS_EXCEPTION_IF_NULL(pair->first);
51 return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
52 }
53 };
54
55 struct CNodeIndexEqual {
operatorCNodeIndexEqual56 bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
57 if (lhs == nullptr || rhs == nullptr) {
58 return false;
59 }
60 if (lhs == rhs) {
61 return true;
62 }
63 if (lhs->first != rhs->first) {
64 return false;
65 }
66 if (lhs->second != rhs->second) {
67 return false;
68 }
69 return true;
70 }
71 };
72
73 template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
74 using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
75 using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
76 using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
77
78 using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
79
// Attribute / flag keys stored in FuncGraph::attrs_.
// Declared `inline constexpr` (C++17) so the header provides a single definition
// shared by all translation units, usable in constant expressions.
inline constexpr char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
inline constexpr char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
inline constexpr char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
inline constexpr char FUNC_GRAPH_FLAG_CORE[] = "core";
inline constexpr char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
inline constexpr char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
inline constexpr char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
inline constexpr char FUNC_GRAPH_FLAG_FORCE_INLINE[] = "force_inline";

// NOTE: "Undeterminate" is a runtime key; keep its spelling even though it is
// not standard English, since stored attributes depend on the exact text.
inline constexpr char kFuncGraphFlagUndetermined[] = "Undeterminate";
inline constexpr char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
inline constexpr char kFuncGraphFlagReAutoMonad[] = "ReAutoMonad";
inline constexpr char kFuncGraphFlagRecursive[] = "Recursive";
93
// Forward declarations of abstract-value classes referenced by FuncGraph's
// interface, together with their shared-pointer aliases.
namespace abstract {
class AbstractKeywordArg;
class AbstractFunction;

using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
using AbstractFunctionPtr = std::shared_ptr<AbstractFunction>;
}  // namespace abstract
100
101 // ANF transform class.
102 // Either a primitive or a func_graph.
103 class FuncGraphTransform {
104 public:
105 enum Type { kGtPrimitive, kGtFuncGraph };
106
107 explicit FuncGraphTransform(const PrimitivePtr prim, const FuncGraphPtr func_graph = nullptr)
prim_(prim)108 : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
109
110 explicit FuncGraphTransform(const FuncGraphPtr &func_graph, const PrimitivePtr &prim = func_graph_prim_)
prim_(prim)111 : prim_(prim), func_graph_(FuncGraphWeakPtr(func_graph)) {}
112
FuncGraphTransform(const FuncGraphTransform & t)113 FuncGraphTransform(const FuncGraphTransform &t) : prim_(t.prim_), func_graph_(t.func_graph_) {}
114
115 ~FuncGraphTransform() = default;
116
type()117 Type type() const {
118 if (IsFuncGraph()) {
119 return kGtFuncGraph;
120 } else {
121 return kGtPrimitive;
122 }
123 }
124
IsPrimitive()125 bool IsPrimitive() const { return (func_graph_.lock() == nullptr); }
IsFuncGraph()126 bool IsFuncGraph() const { return (func_graph_.lock() != nullptr); }
func_graph()127 FuncGraphPtr func_graph() const { return func_graph_.lock(); }
primitive()128 PrimitivePtr primitive() const { return prim_; }
129
130 FuncGraphTransform &operator=(const FuncGraphTransform &t) {
131 if (this != &t) {
132 prim_ = t.prim_;
133 func_graph_ = t.func_graph_;
134 }
135 return *this;
136 }
137
138 private:
139 PrimitivePtr prim_;
140 // FuncGraph will be hold by FuncGraphManager, so weak_ptr is enough here.
141 // And use weak_ptr can break the reference cycle between "primal" and "grad" graph in
142 // FPropRemapper::FinalizeGraph().
143 FuncGraphWeakPtr func_graph_;
144 static const PrimitivePtr func_graph_prim_;
145 };
146
147 class FuncGraphBase : public Value {
148 public:
149 FuncGraphBase() = default;
150
151 ~FuncGraphBase() override = default;
152 MS_DECLARE_PARENT(FuncGraphBase, Value);
153 };
154
155 class FuncGraph : public api::FuncGraph, public FuncGraphBase, public EffectInfoHolder {
156 public:
157 FuncGraph();
158 using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
159
160 ~FuncGraph() override = default;
161 MS_DECLARE_PARENT(FuncGraph, FuncGraphBase);
162
163 // Get the graph's abstract.
164 abstract::AbstractFunctionPtr abstract();
165 abstract::AbstractBasePtr ToAbstract() override;
166
167 // get function graph inputs, but parameters
168 const std::vector<AnfNodePtr> get_inputs() const final;
169 // Return the graph's output, or nullptr if not yet deduced.
170 AnfNodePtr output() const;
171 void set_output(const AnfNodePtr &value, bool force_new_ret = false);
172
parameters()173 const std::vector<AnfNodePtr> ¶meters() const final { return parameters_; }
174 // Append
175 ParameterPtr add_parameter() override;
176 void add_parameter(const ParameterPtr &p) final;
append_parameter(const ParameterPtr & p)177 void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
178 // Prepend
179 virtual ParameterPtr InsertFrontParameter();
180 void InsertFrontParameter(const ParameterPtr &p);
PrependParameter(const ParameterPtr & p)181 void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
set_parameters(const std::vector<AnfNodePtr> & params)182 void set_parameters(const std::vector<AnfNodePtr> ¶ms) { parameters_ = params; }
183 // Add a weight parameter with specific name.
184 ParameterPtr AddWeightParameter(const std::string &name);
185
186 // Create a cnode with given inputs, bound to this graph.
187 CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>()) override;
188 CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs) final;
189
190 // Create a cnode with given inputs, bound to this graph and push back to order list.
191 CNodePtr NewCNodeInOrder(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
192 CNodePtr NewCNodeInOrder(const PrimitivePtr &primitive, const std::vector<AnfNodePtr> &prim_inputs);
193
194 // Create a cnode with given inputs, bound to this graph and push back to front of order list.
195 CNodePtr NewCNodeInFront(const std::vector<AnfNodePtr> &inputs = std::vector<AnfNodePtr>());
196
197 // Create a cnode with given inputs, put it to order list before the position node.
198 CNodePtr NewCNodeBefore(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs);
199
200 // Create a cnode with given inputs, put it to order list after the position node.
201 CNodePtr NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs);
202
203 virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor);
204 // Functions for handling variable argument, keyword-only arguments and variable keyword argument.
205 AnfNodePtr GetDefaultValueByName(const std::string &name);
set_param_default_value(const std::string & name,const AnfNodePtr & node)206 void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
207 parameter_default_value_[name] = node;
208 }
209 void SetDefaultValues(const std::vector<std::string> &name_list, const std::vector<AnfNodePtr> &value_list);
210 void ClearDefaultValues();
211 size_t GetDefaultValueCount();
parameter_default_value()212 std::map<std::string, AnfNodePtr> ¶meter_default_value() { return parameter_default_value_; }
set_has_vararg(bool has_)213 void set_has_vararg(bool has_) { has_vararg_ = has_; }
has_vararg()214 bool has_vararg() const { return has_vararg_; }
215 AnfNodePtr GetVariableArgParameter();
216 std::string GetVariableArgName();
set_has_kwarg(bool has_)217 void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
has_kwarg()218 bool has_kwarg() const { return has_kwarg_; }
set_kwonlyargs_count(int count)219 void set_kwonlyargs_count(int count) { kwonlyargs_count_ = count; }
kwonlyargs_count()220 int kwonlyargs_count() const { return kwonlyargs_count_; }
221 AnfNodePtr GetVariableKwargParameter();
222 std::string GetVariableKwargName();
set_hyper_param_count(size_t count)223 void set_hyper_param_count(size_t count) { hyper_param_count_ = count; }
hyper_param_count()224 size_t hyper_param_count() const { return hyper_param_count_; }
225 int GetPositionalArgsCount() const;
226 AnfNodePtr GetParameterByName(const std::string &name);
227 bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
228 FuncGraphPtr GenerateGraph(const AbstractBasePtrList &args_spec_list);
set_is_generate(bool generated)229 void set_is_generate(bool generated) { is_generated_ = generated; }
is_generated()230 bool is_generated() const { return is_generated_; }
set_is_bprop(bool is_brop)231 void set_is_bprop(bool is_brop) { is_bprop_ = is_brop; }
is_bprop()232 bool is_bprop() const { return is_bprop_; }
233
attrs()234 std::unordered_map<std::string, ValuePtr> &attrs() { return attrs_; }
set_attrs(const std::unordered_map<std::string,ValuePtr> & attrs)235 void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
236 for (auto &attr : attrs) {
237 attrs_[attr.first] = attr.second;
238 }
239 }
240 bool has_flag(const std::string &key);
set_flag(const std::string & key,bool flag)241 void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
erase_flag(const std::string & key)242 void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
243
244 bool has_attr(const std::string &key) const final;
245 ValuePtr get_attr(const std::string &key) const final;
set_attr(const std::string & key,const ValuePtr & value)246 void set_attr(const std::string &key, const ValuePtr &value) final { attrs_[key] = value; }
247
transforms()248 std::unordered_map<std::string, FuncGraphTransform> &transforms() { return transforms_; }
set_transforms(const std::unordered_map<std::string,FuncGraphTransform> & transforms)249 void set_transforms(const std::unordered_map<std::string, FuncGraphTransform> &transforms) {
250 transforms_ = transforms;
251 }
252
get_return()253 CNodePtr get_return() const final { return return_; }
set_return(const CNodePtr & cnode)254 void set_return(const CNodePtr &cnode) final { return_ = cnode; }
255
manager()256 FuncGraphManagerPtr manager() const { return manager_.lock(); }
set_manager(const FuncGraphManagerPtr & m)257 void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
258
get_manager()259 api::FuncGraphManagerPtr get_manager() const final { return manager_.lock(); }
260
261 std::string ToString() const override;
262 GraphDebugInfoPtr debug_info();
set_debug_info(const GraphDebugInfoPtr & info)263 void set_debug_info(const GraphDebugInfoPtr &info) {
264 if (info == nullptr) {
265 MS_LOG(EXCEPTION) << "Graph set null debug info";
266 }
267 this->debug_info_ = info;
268 }
269 // Get all nodes belonging to this func graph.
270 const AnfNodeSet &nodes() const final;
271 void CopyNodes(const FuncGraphPtr &source);
272 void ClearNodes();
273 void AddNode(const AnfNodePtr &node);
274 void DropNode(const AnfNodePtr &node);
275
276 // Get all value_nodes belonging to this func graph.
277 const AnfNodeCounterMap &value_nodes() const;
278 void CopyValueNodes(const FuncGraphPtr &source);
279 void ClearValueNodes();
280 void AddValueNode(const AnfNodePtr &node, int count = 1);
281 void DropValueNode(const AnfNodePtr &node);
282
283 // Get all free vars directly used in this func graph.
284 const AnfNodeCounterMap &free_variables() const;
285 void CopyFreeVariables(const FuncGraphPtr &source);
286 void ClearFreeVariables();
287 bool AddFreeVariable(const AnfNodePtr &node, int count = 1);
288 bool DropFreeVariable(const AnfNodePtr &node);
289
290 // Get all vars required by this func graph.
291 const BaseRefCounterMap &free_variables_total();
292
293 // Return the set of graphs free_variables_total belong to.
294 std::vector<AnfNodePtr> free_variables_nodes();
295
296 // Get all vars that are func graphs
297 std::vector<FuncGraphPtr> free_variables_func_graphs();
298
299 // Get all value nodes of func graph directly used by this func graph.
300 const FuncGraphCounterMap &func_graphs_used() const;
301 void CopyFuncGraphsUsed(const FuncGraphPtr &source);
302 void ClearFuncGraphsUsed();
303 bool AddFuncGraphUsed(const FuncGraphPtr &fg, int count = 1);
304 bool DropFuncGraphUsed(const FuncGraphPtr &fg);
305
306 // Get all value nodes in the inputs of J directly used by this func graph.
307 const std::unordered_map<AnfNodePtr, int> &j_value_nodes() const;
308 void CopyJValueNodes(const FuncGraphPtr &source);
309 void ClearJValueNodes();
310 void AddJValueNode(const AnfNodePtr &value_node, int count = 1);
311 void DropJValueNode(const AnfNodePtr &value_node);
312
313 // Get all func graphs nested used by this func graph.
314 const FuncGraphSet &func_graphs_used_total();
315
316 // Get all user value nodes of this func graph, by CNode and its input's index.
317 const CNodeIndexCounterMap &func_graph_cnodes_index() const;
318 void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
319 void ClearFuncGraphCNodesIndex();
320 void AddFuncGraphCNodeIndex(const CNodeIndexPairPtr &node, int count = 1);
321 void DropFuncGraphCNodeIndex(const CNodeIndexPairPtr &node);
322
323 // Return the parent of this graph.
324 FuncGraphPtr parent();
325
326 // Return the children of this graph.
327 const FuncGraphSet &children();
328
329 // Return the scope of this graph, scope have graph self but children not have.
330 const FuncGraphSet &scope();
331
332 // Return whether this graph is recursive.
333 bool recursive();
334
335 // Return graphs which forms a recursive loop.
336 std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs();
337
hash()338 std::size_t hash() const override { return std::hash<const FuncGraph *>{}(this); }
339
340 void DumpFuncGraph(const std::string &path = "./func_graph.dot");
341
342 bool operator==(const Value &other) const override {
343 if (other.isa<FuncGraph>()) {
344 return &other == this;
345 } else {
346 return false;
347 }
348 }
349 void GenerateVarParams(const FuncGraphPtr &specialized_graph, int variable_args_count, int pos_args_input_count,
350 std::vector<AnfNodePtr> *specialized_parameter_list,
351 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
352
353 void GenerateKwParams(const FuncGraphPtr &specialized_graph,
354 const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list,
355 std::vector<AnfNodePtr> *specialized_parameter_list,
356 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
357
358 void GenerateDefaultValue(const FuncGraphPtr &specialized_graph,
359 const std::vector<AnfNodePtr> &specialized_parameter_list,
360 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
361
paramter_obj_nodes()362 const std::vector<AnfNodePtr> ¶mter_obj_nodes() const { return paramter_obj_nodes_; }
add_parameter_obj_node(const AnfNodePtr & p)363 void add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); }
364
365 std::unordered_map<std::string, ValuePtr> attrs_;
366 std::unordered_map<std::string, FuncGraphTransform> transforms_;
367 // Parameter default value.
368 std::map<std::string, AnfNodePtr> parameter_default_value_;
369 size_t seen_;
370
371 std::list<CNodePtr> GetOrderedCnodes();
372 void EraseUnusedNodeInOrder(const AnfNodePtr &n);
373 void EraseUnusedNodeInOrder();
374 void DumpCNodeList();
order_list()375 const OrderedSet<CNodePtr> &order_list() const { return order_; }
376
set_order_list(OrderedSet<CNodePtr> && order_list)377 void set_order_list(OrderedSet<CNodePtr> &&order_list) { order_ = std::move(order_list); }
378
379 // Add a cnode at the end of order list.
AppendOrderList(const CNodePtr & cnode)380 void AppendOrderList(const CNodePtr &cnode) { order_.push_back(cnode); }
381
382 // Prepend cnode at the front of order list.
PrependOrderList(const CNodePtr & cnode)383 void PrependOrderList(const CNodePtr &cnode) { order_.push_front(cnode); }
384
385 // Maintain cnode order list when a cnode is replaced by a new one.
386 void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
387
388 // Clear cnode order list.
ClearOrderList()389 void ClearOrderList() { order_.clear(); }
390
stub()391 bool stub() const { return stub_; }
set_stub(bool stub)392 void set_stub(bool stub) { stub_ = stub; }
set_drawer(Drawer drawer)393 static void set_drawer(Drawer drawer) { drawer_ = drawer; }
switch_input()394 std::shared_ptr<bool> switch_input() const { return switch_input_; }
set_switch_input(std::shared_ptr<bool> switch_input)395 void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; }
switch_layer_input()396 std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
set_switch_layer_input(std::shared_ptr<bool> switch_layer_input)397 void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
398 bool ContainMultiTarget() const;
stage()399 int64_t stage() { return stage_; }
set_stage(int64_t stage)400 void set_stage(int64_t stage) { stage_ = stage; }
401
dropped()402 bool dropped() const { return dropped_; }
set_dropped(bool dropped)403 void set_dropped(bool dropped) { dropped_ = dropped; }
404
bprop_hash()405 std::string bprop_hash() const { return bprop_hash_; }
set_bprop_hash(const std::string & bprop_hash)406 void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }
407
modify_output()408 bool modify_output() const { return modify_output_; }
set_modify_output(bool modify_output)409 void set_modify_output(bool modify_output) { modify_output_ = modify_output; }
used_forward_nodes()410 const std::unordered_set<AnfNodePtr> &used_forward_nodes() const { return used_forward_nodes_; }
411 void set_used_forward_nodes(const std::vector<AnfNodePtr> &used_forward_nodes);
ClearUsedForwardNodes()412 void ClearUsedForwardNodes() { used_forward_nodes_.clear(); }
413
414 private:
415 // Only used for func_graph manager to control resource free.
attached_mng_cnt()416 int attached_mng_cnt() const { return attached_mng_cnt_; }
IncAttachedMngCnt()417 void IncAttachedMngCnt() { attached_mng_cnt_++; }
DecAttachedMngCnt()418 void DecAttachedMngCnt() { attached_mng_cnt_--; }
419 // Clear all info from manager.
420 void ClearAllManagerInfo();
421
422 // Graph is manipulated by manager and others.
423 friend FuncGraphManager;
424
425 // All nodes of the function.
426 AnfNodeSet nodes_;
427
428 // All value nodes of the function.
429 AnfNodeCounterMap value_nodes_;
430
431 // All func graph value nodes of the function.
432 FuncGraphCounterMap func_graphs_used_;
433
434 // All free variables of the function.
435 AnfNodeCounterMap free_variables_;
436
437 // All value nodes calling J in the function.
438 std::unordered_map<AnfNodePtr, int> j_value_nodes_;
439
440 // All user value nodes of this func graph, recording by CNode and its input's index.
441 CNodeIndexCounterMap func_graph_cnodes_index_;
442
443 // Parameters of this function.
444 std::vector<AnfNodePtr> parameters_;
445 std::vector<AnfNodePtr> paramter_obj_nodes_;
446
447 // Whether there is a *args and **kwargs, and count kwonlyargs'number.
448 bool has_vararg_;
449 bool has_kwarg_;
450 int kwonlyargs_count_;
451 // Hyper param is placed on the top graph,
452 // and positioned in the end of the param list, so we record the number to trace the position.
453 size_t hyper_param_count_;
454 // Argument input list for the graph used to generate this graph.
455 bool is_generated_;
456
457 bool is_bprop_;
458
459 // CNode that calls 'return' primitive.
460 // We use shared pointer to manage it.
461 CNodePtr return_;
462
463 // Back-ref to its manager.
464 // Hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph.
465 // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles.
466 // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs.
467 // In some ut test cases, they may use local FuncGraphManager in function which
468 // generating the func graph, when go outside of that function, func graph will have no
469 // FuncGraphManager. In that special case, Manage() should be called to make the func graph
470 // managed.
471 std::weak_ptr<FuncGraphManager> manager_;
472 int attached_mng_cnt_ = 0;
473
474 GraphDebugInfoPtr debug_info_;
475 void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph,
476 const std::vector<AnfNodePtr> &kwarg_keys_tuple_nodes,
477 const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes,
478 std::unordered_map<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
479
480 // CNode order which relates to origin code order.
481 OrderedSet<CNodePtr> order_;
482 bool stub_;
483 inline static Drawer drawer_ = nullptr;
484 // Design switch_input and switch_layer_input as a ptr to
485 // share between derived backpropagator and cloned graphs.
486 std::shared_ptr<bool> switch_input_;
487 std::shared_ptr<bool> switch_layer_input_;
488 int64_t stage_;
489 std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
490 abstract::AbstractBasePtrListEqual>
491 func_graph_cache_;
492
493 // If the graph was changed, it should be dropped in cache data_converter::object_map_
494 // which used by ConvertToFuncGraph.
495 bool dropped_ = false;
496 // If the graph is a bprop graph, it should has a hash of the bprop directory.
497 std::string bprop_hash_;
498
499 // If the graph is decorated by @ms_function and runs grad process in pynative mode,
500 // forward nodes used in grad graph will be added to output for holding output values.
501 bool modify_output_ = false;
502 std::unordered_set<AnfNodePtr> used_forward_nodes_;
503 };
504
NewCNode(const std::vector<AnfNodePtr> & inputs,const FuncGraphPtr & fg)505 inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
506 MS_EXCEPTION_IF_NULL(fg);
507 return fg->NewCNode(inputs);
508 }
509
510 size_t NewFgSeenGeneration();
511
512 // Find the root cnodes of a segment of cnodes.
513 std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
514 // Find the leaf cnodes of a segment of cnodes.
515 std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment);
516 } // namespace mindspore
517
518 #endif // MINDSPORE_CORE_IR_FUNC_GRAPH_H_
519