1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_H_
20 #define MINDSPORE_CORE_IR_FUNC_GRAPH_H_
21
22 #include <set>
23 #include <map>
24 #include <string>
25 #include <vector>
26 #include <list>
27 #include <unordered_map>
28 #include <memory>
29 #include <functional>
30 #include <utility>
31
32 #include "utils/hash_map.h"
33 #include "utils/hash_set.h"
34 #include "utils/ordered_set.h"
35 #include "utils/ordered_map.h"
36 #include "mindapi/base/macros.h"
37 #include "base/base_ref.h"
38 #include "base/effect_info.h"
39 #include "ir/anf.h"
40 #include "ir/manager.h"
41 #include "ir/func_graph_transform.h"
42 #include "ir/func_graph_base.h"
43 #include "abstract/abstract_value.h"
44 #include "mindspore/core/symbolic_shape/symbol_engine.h"
45
46 namespace mindspore {
47 using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
48 using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
49
50 struct CNodeIndexHasher {
operatorCNodeIndexHasher51 std::size_t operator()(const CNodeIndexPairPtr pair) const {
52 MS_EXCEPTION_IF_NULL(pair);
53 MS_EXCEPTION_IF_NULL(pair->first);
54 return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
55 }
56 };
57
58 struct CNodeIndexEqual {
operatorCNodeIndexEqual59 bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
60 if (lhs == nullptr || rhs == nullptr) {
61 return false;
62 }
63 if (lhs == rhs) {
64 return true;
65 }
66 if (lhs->first != rhs->first) {
67 return false;
68 }
69 if (lhs->second != rhs->second) {
70 return false;
71 }
72 return true;
73 }
74 };
75
76 template <typename KeyT, class CounterHash = std::hash<KeyT>, class CounterEqual = std::equal_to<KeyT>>
77 using CounterOrderedMap = OrderedMap<KeyT, int, CounterHash, CounterEqual>;
78 using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
79 using CNodeIndexCounterMap = CounterOrderedMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual>;
80
81 using FuncGraphMap = OrderedMap<FuncGraphPtr, int>;
82
// String keys for FuncGraph flags/attributes (e.g. for use with set_flag / has_flag / set_attr).
// constexpr char arrays have the same type (const char[N]) and internal linkage as `const char[]`.
constexpr char FUNC_GRAPH_FLAG_IGNORE_VALUE[] = "ignore_value";
constexpr char FUNC_GRAPH_FLAG_VMAP_TRANSFORMED[] = "vmap_transformed";
constexpr char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
constexpr char FUNC_GRAPH_FLAG_PRIMAL_OF_BPROP[] = "primal_of_bprop";
constexpr char FUNC_GRAPH_FLAG_SPARSE_BPROP[] = "sparse_bprop";
constexpr char FUNC_GRAPH_FLAG_NO_INLINE[] = "no_inline";
constexpr char FUNC_GRAPH_FLAG_CELL_REUSE[] = "cell_reuse";
constexpr char FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER[] = "lazy_inline_order";
constexpr char FUNC_GRAPH_FLAG_AFTER_BLOCK[] = "after_block";
constexpr char FUNC_GRAPH_FLAG_CORE[] = "core";
constexpr char FUNC_GRAPH_FLAG_K_GRAPH[] = "k_graph";
constexpr char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel";
constexpr char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
constexpr char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute";
constexpr char FUNC_GRAPH_RECOMPUTE_K_GRAPH[] = "recompute_k_graph";
constexpr char FUNC_GRAPH_RECOMPUTE_GRAD_GRAPH[] = "recompute_grad_graph";
constexpr char FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH[] = "not_recompute_k_graph";
constexpr char FUNC_GRAPH_FLAG_FORCE_INLINE[] = "force_inline";
constexpr char FUNC_GRAPH_FLAG_DUMP[] = "dump";
constexpr char FUNC_GRAPH_FLAG_DYNAMIC_SHAPE[] = "dynamic_shape";
constexpr char FUNC_GRAPH_FLAG_NO_RECURSIVE[] = "no_recursive";
constexpr char FUNC_GRAPH_FLAG_ARGS_NO_EXPAND[] = "args_no_expand";
constexpr char FUNC_GRAPH_FLAG_PROXY_GRAPH[] = "proxy_graph";
constexpr char FUNC_GRAPH_FLAG_NO_CHILD_GRAPH[] = "no_child_graph";

// Flag keys that follow the kCamelCase naming convention.
// NOTE: the value of kFuncGraphFlagUndetermined is spelled "undeterminate" on purpose-or-not; it is a
// runtime key and must stay byte-identical for compatibility.
constexpr char kFuncGraphFlagUndetermined[] = "undeterminate";
constexpr char kFuncGraphFlagBackPropEntry[] = "back_prop_entry";
constexpr char kFuncGraphFlagReAutoMonad[] = "re_auto_monad";
constexpr char kFuncGraphFlagRecursive[] = "recursive";
constexpr char kFuncGraphFlagMetaFuncGraphBprop[] = "meta_fg_bprop";
constexpr char kFuncGraphFlagAddedForwardU[] = "added_forward_u";
114
115 class MS_CORE_API FuncGraph : public FuncGraphBase, public EffectInfoHolder {
116 public:
117 using Drawer = std::function<void(const std::string &, const FuncGraphPtr &)>;
118
119 FuncGraph();
120 explicit FuncGraph(GraphDebugInfoPtr &&debug_info);
121 ~FuncGraph();
122 MS_DECLARE_PARENT(FuncGraph, FuncGraphBase);
123
124 void DoBreakLoop() override;
125
126 // Get the graph's abstract.
127 abstract::AbstractFunctionPtr abstract();
128 abstract::AbstractBasePtr ToAbstract() override;
129
130 // get function graph inputs, but parameters
131 const AnfNodePtrList get_inputs() const;
parameters()132 const AnfNodePtrList ¶meters() const { return parameters_; }
133 // Append
134 virtual ParameterPtr add_parameter();
135 ParameterPtr add_parameter(NodeDebugInfoPtr &&debug_info);
136 void add_parameter(const ParameterPtr ¶m);
append_parameter(const ParameterPtr & p)137 void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
138 // Prepend
139 virtual ParameterPtr InsertFrontParameter();
140 void InsertFrontParameter(const ParameterPtr ¶m);
PrependParameter(const ParameterPtr & p)141 void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); }
set_parameters(const AnfNodePtrList & params)142 void set_parameters(const AnfNodePtrList ¶ms) { parameters_ = params; }
set_parameters(AnfNodePtrList && params)143 void set_parameters(AnfNodePtrList &¶ms) { parameters_ = std::move(params); }
144 // Add a FV weight parameter with specific name.
145 ParameterPtr AddFvParameter(const std::string &name, const ValuePtr &default_value);
146
147 // Create a CNode with given inputs, bound to this graph.
148 virtual CNodePtr NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs);
149 virtual CNodePtr NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs);
150
151 // @deprecated
152 // To use 'CNodePtr NewCNodeWeak(AnfNodeWeakPtrList &&weak_inputs)' instead.
153 virtual CNodePtr NewCNode(AnfNodePtrList &&inputs);
154 // @deprecated
155 // To use 'CNodePtr NewCNodeWeak(const AnfNodeWeakPtrList &weak_inputs)' instead.
156 virtual CNodePtr NewCNode(const AnfNodePtrList &inputs);
157
158 CNodePtr NewCNode(const PrimitivePtr &primitive, const AnfNodePtrList &inputs);
159
160 // Create a CNode with given weak inputs, bound to this graph and push back to order list.
161 CNodePtr NewCNodeInOrderWeak(AnfNodeWeakPtrList &&weak_inputs);
162 CNodePtr NewCNodeInOrderWeak(const AnfNodeWeakPtrList &weak_inputs);
163
164 // Create a CNode with given inputs, bound to this graph and push back to order list.
165 CNodePtr NewCNodeInOrder(AnfNodePtrList &&inputs);
166 CNodePtr NewCNodeInOrder(const AnfNodePtrList &inputs = AnfNodePtrList());
167 CNodePtr NewCNodeInOrder(const PrimitivePtr &primitive, const AnfNodePtrList &inputs);
168
169 // Create a CNode with given inputs, bound to this graph and push back to front of order list.
170 CNodePtr NewCNodeInFront(const AnfNodePtrList &inputs = AnfNodePtrList());
171
172 // Create a CNode with given inputs, put it to order list before the position node.
173 CNodePtr NewCNodeBefore(const AnfNodePtr &position, const AnfNodePtrList &inputs);
174
175 // Create a CNode with given inputs, put it to order list after the position node.
176 CNodePtr NewCNodeAfter(const AnfNodePtr &position, const AnfNodePtrList &inputs);
177
178 // Functions for handling variable argument, keyword-only arguments and variable keyword argument.
179 AnfNodePtr GetDefaultValueByName(const std::string &name);
set_param_default_value(const std::string & name,const AnfNodePtr & node)180 void set_param_default_value(const std::string &name, const AnfNodePtr &node) {
181 parameter_default_value_[name] = node;
182 }
183 void SetDefaultValues(const std::vector<std::string> &name_list, const AnfNodePtrList &value_list);
184 void ClearDefaultValues();
185 size_t GetDefaultValueCount();
parameter_default_value()186 std::map<std::string, AnfNodePtr> ¶meter_default_value() { return parameter_default_value_; }
set_has_vararg(bool has_)187 void set_has_vararg(bool has_) { has_vararg_ = has_; }
has_vararg()188 bool has_vararg() const { return has_vararg_; }
189 // Parameters are ordered as: Positional Parameters, Kwonlyargs, *Varargs, **Kwargs, HyperParam;
190 AnfNodePtr GetVariableArgParameter();
191 std::string GetVariableArgName();
set_has_kwarg(bool has_)192 void set_has_kwarg(bool has_) { has_kwarg_ = has_; }
has_kwarg()193 bool has_kwarg() const { return has_kwarg_; }
set_kwonlyargs_count(int count)194 void set_kwonlyargs_count(int count) { kw_only_args_count_ = count; }
kwonlyargs_count()195 int kwonlyargs_count() const { return kw_only_args_count_; }
196 AnfNodePtr GetVariableKwargParameter();
197 std::string GetVariableKwargName();
198 AnfNodePtrList GetKwOnlyArgsParameters();
set_fv_param_count(size_t count)199 void set_fv_param_count(size_t count) { fv_param_count_ = count; }
fv_param_count()200 size_t fv_param_count() const { return fv_param_count_; }
201 int GetPositionalArgsCount() const;
202 AnfNodePtr GetParameterByName(const std::string &name);
203 bool NeedGenerate(const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list);
204 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list);
set_is_generate(bool generated)205 void set_is_generate(bool generated) { is_generated_ = generated; }
is_generated()206 bool is_generated() const { return is_generated_; }
207
attrs()208 mindspore::HashMap<std::string, ValuePtr> &attrs() { return attrs_; }
set_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)209 void set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
210 for (auto &attr : attrs) {
211 attrs_[attr.first] = attr.second;
212 }
213 }
214 bool has_flag(const std::string &key) const;
set_flag(const std::string & key,bool flag)215 void set_flag(const std::string &key, bool flag) { attrs_[key] = MakeValue(flag); }
erase_flag(const std::string & key)216 void erase_flag(const std::string &key) { (void)attrs_.erase(key); }
217
218 bool has_attr(const std::string &key) const;
219 ValuePtr get_attr(const std::string &key) const;
set_attr(const std::string & key,const ValuePtr & value)220 void set_attr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
221
transforms()222 mindspore::HashMap<std::string, FuncGraphTransform> &transforms() { return transforms_; }
set_transforms(const mindspore::HashMap<std::string,FuncGraphTransform> & transforms)223 void set_transforms(const mindspore::HashMap<std::string, FuncGraphTransform> &transforms) {
224 transforms_ = transforms;
225 }
226
227 // Return the graph's output, or nullptr if not yet deduced.
228 AnfNodePtr output() const;
229 void set_output(const AnfNodePtr &value, bool force_new_ret = false);
230
get_return()231 CNodePtr get_return() const { return return_.lock(); }
return_node()232 const CNodePtr return_node() const { return return_.lock(); }
set_return(const CNodePtr & cnode)233 void set_return(const CNodePtr &cnode) {
234 return_owner_ = cnode;
235 return_ = CNodeWeakPtr(cnode);
236 }
ResetReturnOwner()237 void ResetReturnOwner() { return_owner_.reset(); }
238
239 const std::list<AnfNodePtr> &own_nodes() const;
240 void AddOwnNode(const AnfNodePtr &node);
241 void AddOwnNode(const AnfNodePtrList &nodes);
242 void AddOwnNode(const AnfNodeWeakPtrList &weak_nodes);
243 void RemoveOwnNode(const AnfNodePtr &node);
244 void ResetOwnNodes();
245
manager()246 FuncGraphManagerPtr manager() const { return manager_.lock(); }
set_manager(const FuncGraphManagerPtr & m)247 void set_manager(const FuncGraphManagerPtr &m) { manager_ = std::weak_ptr<FuncGraphManager>(m); }
248
249 std::string ToString() const override;
250 GraphDebugInfoPtr debug_info();
set_debug_info(const GraphDebugInfoPtr & info)251 void set_debug_info(const GraphDebugInfoPtr &info) {
252 if (info == nullptr) {
253 MS_LOG(INTERNAL_EXCEPTION) << "Graph set null debug info";
254 }
255 this->debug_info_ = info;
256 }
257 // Get all nodes belonging to this func graph.
258 const AnfNodeSet &nodes() const;
259 const AnfNodeSet &switch_nodes() const;
260 void CopyNodes(const FuncGraphPtr &source);
261 void ClearNodes();
262 void AddNode(const AnfNodePtr &node);
263 void DropNode(const AnfNodePtr &node);
264
265 // Get all value_nodes belonging to this func graph.
266 const AnfNodeCounterMap &value_nodes() const;
267 void CopyValueNodes(const FuncGraphPtr &source);
268 void ClearValueNodes();
269 void AddValueNode(const AnfNodePtr &node, int count = 1);
270 void DropValueNode(const AnfNodePtr &node);
271
272 // Get all free vars directly used in this func graph.
273 const AnfNodeCounterMap &free_variables() const;
274 void CopyFreeVariables(const FuncGraphPtr &source);
275 void ClearFreeVariables();
276 bool AddFreeVariable(const AnfNodePtr &node, int count = 1);
277 bool DropFreeVariable(const AnfNodePtr &node);
278
279 // Get all vars required by this func graph.
280 const BaseRefCounterMap &free_variables_total();
281
282 // Return the set of graphs free_variables_total belong to.
283 AnfNodePtrList free_variables_nodes();
284
285 // Get all vars that are func graphs
286 std::vector<FuncGraphPtr> free_variables_func_graphs();
287
288 // Get all value nodes of func graph directly used by this func graph.
289 const FuncGraphCounterMap &func_graphs_used() const;
290 void CopyFuncGraphsUsed(const FuncGraphPtr &source);
291 void ClearFuncGraphsUsed();
292 bool AddFuncGraphUsed(const FuncGraphPtr &fg, int count = 1);
293 bool DropFuncGraphUsed(const FuncGraphPtr &fg);
294
295 // Get all value nodes in the inputs of MetaFgPrim directly used by this func graph.
296 const mindspore::HashMap<AnfNodePtr, int> &meta_fg_prim_value_nodes() const;
297 void CopyMetaFgPrimValueNodes(const FuncGraphPtr &source);
298 void ClearMetaFgPrimValueNodes();
299 void AddMetaFgPrimValueNode(const AnfNodePtr &value_node, int count = 1);
300 void DropMetaFgPrimValueNode(const AnfNodePtr &value_node);
301
302 // Get all func graphs nested used by this func graph.
303 const FuncGraphSet &func_graphs_used_total();
304
305 // Get all user value nodes of this func graph, by CNode and its input's index.
306 const CNodeIndexCounterMap &func_graph_cnodes_index() const;
307 void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
308 void ClearFuncGraphCNodesIndex();
309 void AddFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair, int count = 1);
310 void DropFuncGraphCNodeIndex(const CNodeIndexPairPtr &pair);
311
312 // Return the parent of this graph.
313 FuncGraphPtr parent();
314
315 // Return the children of this graph.
316 const FuncGraphSet &children();
317
318 // Return the scope of this graph, scope have graph self but children not have.
319 const FuncGraphSet &scope();
320
321 // Return whether this graph is recursive.
322 bool recursive();
323
324 // Return graphs which forms a recursive loop.
325 std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs();
326
hash()327 std::size_t hash() const override { return PointerHash<FuncGraph>{}(this); }
328
329 bool operator==(const Value &other) const override {
330 if (other.isa<FuncGraph>()) {
331 return &other == this;
332 } else {
333 return false;
334 }
335 }
336 void GenerateVarParams(const FuncGraphPtr &specialized_graph, int variable_args_count, int pos_args_input_count,
337 AnfNodePtrList *specialized_parameter_list,
338 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
339
340 void GenerateKwParams(const FuncGraphPtr &specialized_graph,
341 const std::vector<abstract::AbstractKeywordArgPtr> &kwarg_list, int pos_args_input_count,
342 AnfNodePtrList *specialized_parameter_list,
343 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
344
345 void GenerateDefaultValue(const FuncGraphPtr &specialized_graph, const AnfNodePtrList &specialized_parameter_list,
346 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
347
parameter_obj_nodes()348 const AnfNodePtrList ¶meter_obj_nodes() const { return parameter_obj_nodes_; }
add_parameter_obj_node(const AnfNodePtr & p)349 void add_parameter_obj_node(const AnfNodePtr &p) { parameter_obj_nodes_.push_back(p); }
350
351 mindspore::HashMap<std::string, ValuePtr> attrs_;
352 mindspore::HashMap<std::string, FuncGraphTransform> transforms_;
353 // Parameter default value.
354 std::map<std::string, AnfNodePtr> parameter_default_value_;
355
356 SeenNum seen_{0};
357 SeenNum extra_seen_{0};
358
359 std::list<CNodePtr> GetOrderedCnodes();
360 void EraseUnusedNodeInOrder(const AnfNodePtr &node);
361 void EraseUnusedNodeInOrder();
362 void DumpCNodeList();
order_list()363 const std::list<CNodeWeakPtr> &order_list() const { return order_; }
364
set_order_list(std::list<CNodeWeakPtr> && order_list)365 void set_order_list(std::list<CNodeWeakPtr> &&order_list) { order_ = std::move(order_list); }
366
367 // Add a CNode at the end of order list.
AppendOrderList(const CNodePtr & cnode)368 void AppendOrderList(const CNodePtr &cnode) { (void)order_.emplace_back(CNodeWeakPtr(cnode)); }
369
370 // Prepend CNode at the front of order list.
PrependOrderList(const CNodePtr & cnode)371 void PrependOrderList(const CNodePtr &cnode) { (void)order_.emplace_front(CNodeWeakPtr(cnode)); }
372
373 // Maintain CNode order list when a CNode is replaced by a new one.
374 void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
375
376 // Clear CNode order list.
ClearOrderList()377 void ClearOrderList() { order_.clear(); }
378
stub()379 bool stub() const { return stub_; }
set_stub(bool stub)380 void set_stub(bool stub) { stub_ = stub; }
381
indirect()382 std::shared_ptr<bool> indirect() {
383 // Lazy initialization.
384 if (!indirect_) {
385 indirect_ = std::make_shared<bool>(false);
386 }
387 return indirect_;
388 }
set_indirect(std::shared_ptr<bool> indirect)389 void set_indirect(std::shared_ptr<bool> indirect) { indirect_ = indirect; }
390
391 void SetMultiTarget() const;
exist_multi_target()392 bool exist_multi_target() const { return exist_multi_target_; }
set_exist_multi_target(bool exist_multi_target)393 void set_exist_multi_target(bool exist_multi_target) { exist_multi_target_ = exist_multi_target; }
stage()394 int64_t stage() const { return stage_; }
set_stage(int64_t stage)395 void set_stage(int64_t stage) { stage_ = stage; }
segment()396 int64_t segment() const { return segment_; }
set_segment(int64_t segment)397 void set_segment(int64_t segment) { segment_ = segment; }
dynamic_shape()398 bool dynamic_shape() { return dynamic_shape_; }
set_dynamic_shape(bool dynamic_shape)399 void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
400
dropped()401 bool dropped() const { return dropped_; }
set_dropped(bool dropped)402 void set_dropped(bool dropped) { dropped_ = dropped; }
403
bprop_hash()404 std::string bprop_hash() const { return bprop_hash_; }
set_bprop_hash(const std::string & bprop_hash)405 void set_bprop_hash(const std::string &bprop_hash) { bprop_hash_ = bprop_hash; }
406
bprop_filepath()407 std::string bprop_filepath() const { return bprop_filepath_; }
set_bprop_filepath(const std::string & bprop_filepath)408 void set_bprop_filepath(const std::string &bprop_filepath) { bprop_filepath_ = bprop_filepath; }
409
modify_output()410 bool modify_output() const { return modify_output_; }
set_modify_output(bool modify_output)411 void set_modify_output(bool modify_output) { modify_output_ = modify_output; }
used_forward_nodes()412 const mindspore::OrderedSet<AnfNodePtr> &used_forward_nodes() const { return used_forward_nodes_; }
413 void set_used_forward_nodes(const AnfNodePtrList &used_forward_nodes);
ClearUsedForwardNodes()414 void ClearUsedForwardNodes() { used_forward_nodes_.clear(); }
415
is_tensor_condition_branch()416 bool is_tensor_condition_branch() const { return is_tensor_condition_branch_; }
set_is_tensor_condition_branch(bool is_tensor_condition_branch)417 void set_is_tensor_condition_branch(bool is_tensor_condition_branch) {
418 is_tensor_condition_branch_ = is_tensor_condition_branch;
419 }
420
421 /// \brief Topological sort a graph from the given end node.
422 ///
423 /// \param[in] node The end node of the graph to be sorted.
424 ///
425 /// \return The sorted nodes.
426 static AnfNodePtrList TopoSort(const AnfNodePtr &node);
427
set_python_obj(const ValuePtr & python_obj)428 void set_python_obj(const ValuePtr &python_obj) { python_obj_ = python_obj; }
python_obj()429 ValuePtr python_obj() const { return python_obj_; }
430
phase()431 const std::string &phase() const { return phase_; }
432
set_symbol_engine(const SymbolEnginePtr & se)433 void set_symbol_engine(const SymbolEnginePtr &se) { symbol_engine_ = se; }
symbol_engine()434 const SymbolEnginePtr &symbol_engine() const { return symbol_engine_; }
435
436 // Only used for func_graph manager to control resource free.
attached_mng_cnt()437 int attached_mng_cnt() const { return attached_mng_cnt_; }
438
439 // Reserve the func graph, not to release in manager.
set_reserved(bool reserved)440 void set_reserved(bool reserved) { reserved_ = reserved; }
reserved()441 bool reserved() const { return reserved_; }
442
443 private:
444 // Only used for func_graph manager to control resource free.
IncAttachedMngCnt()445 void IncAttachedMngCnt() { attached_mng_cnt_++; }
DecAttachedMngCnt()446 void DecAttachedMngCnt() { attached_mng_cnt_--; }
447 // Clear all info from manager.
448 void ClearAllResource();
449
450 // Graph is manipulated by manager and others.
451 friend FuncGraphManager;
452
453 // All nodes of the function.
454 AnfNodeSet nodes_;
455
456 // All switch nodes of the function.
457 AnfNodeSet switch_nodes_;
458
459 // All value nodes of the function.
460 AnfNodeCounterMap value_nodes_;
461
462 // All func graph value nodes of the function.
463 FuncGraphCounterMap func_graphs_used_;
464
465 // All free variables of the function.
466 AnfNodeCounterMap free_variables_;
467
468 // All value nodes calling MetaFgPrim in the function.
469 mindspore::HashMap<AnfNodePtr, int> meta_fg_prim_value_nodes_;
470
471 // All user value nodes of this func graph, recording by CNode and its input's index.
472 CNodeIndexCounterMap func_graph_cnodes_index_;
473
474 // Parameters of this function.
475 AnfNodePtrList parameters_;
476 AnfNodePtrList parameter_obj_nodes_;
477
478 // Whether there is a *args and **kwargs, and count kw_only_args'number.
479 bool has_vararg_;
480 bool has_kwarg_;
481 bool exist_multi_target_;
482 int kw_only_args_count_;
483 // Hyper param is used as free variable and placed on the top graph.
484 // and positioned in the end of the param list, so we record the number to trace the position.
485 size_t fv_param_count_;
486 // Argument input list for the graph used to generate this graph.
487 bool is_generated_;
488 // CNode that calls 'return' primitive.
489 // We use shared pointer to manage it.
490 CNodeWeakPtr return_;
491 // Before release all func graphs in Manager, reset the owner firstly.
492 CNodePtr return_owner_;
493
494 // Back-ref to its manager.
495 // Hold a weak ref to FuncGraphManager as FuncGraphManager also hold many ref to FuncGraph.
496 // Otherwise, FuncGraph and FuncGraphManager will make a reference cycles.
497 // Notes: Normally, there will be a global FuncGraphManager, it will hold all FuncGraphs.
498 // In some ut test cases, they may use local FuncGraphManager in function which
499 // generating the func graph, when go outside of that function, func graph will have no
500 // FuncGraphManager. In that special case, Manage() should be called to make the func graph
501 // managed.
502 std::weak_ptr<FuncGraphManager> manager_;
503 int attached_mng_cnt_{0};
504
505 GraphDebugInfoPtr debug_info_;
506 void GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, const AnfNodePtrList &kwarg_keys_tuple_nodes,
507 const AnfNodePtrList &kwarg_values_tuple_nodes,
508 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl_nodes) const;
509
510 // CNode order which relates to origin code order.
511 std::list<CNodeWeakPtr> order_;
512 bool stub_;
513
514 // The graph is used as some input of Switch, SwitchLayer, or Partial.
515 std::shared_ptr<bool> indirect_;
516
517 int64_t stage_;
518 int64_t segment_;
519 bool dynamic_shape_ = false;
520 std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
521 abstract::AbstractBasePtrListEqual>
522 func_graph_cache_;
523
524 // If the graph was changed, it should be dropped in cache data_converter::object_map_
525 // which used by ConvertToFuncGraph.
526 bool dropped_{false};
527 // If the graph is a bprop graph, it should has a hash of the bprop function.
528 std::string bprop_hash_;
529 // If the graph is a bprop graph, it should has a filepath of the bprop function.
530 std::string bprop_filepath_;
531
532 // If the graph is decorated with @jit and runs grad process in pynative mode,
533 // forward nodes used in grad graph will be added to output for holding output values.
534 bool modify_output_{false};
535 mindspore::OrderedSet<AnfNodePtr> used_forward_nodes_;
536 // If the func_graph is input of switch node, and the condition of switch is AbstractTensor, need set true.
537 bool is_tensor_condition_branch_{false};
538 // Corresponding python obj.
539 ValuePtr python_obj_{nullptr};
540 std::string phase_;
541 // Own all nodes in the func graph.
542 std::list<AnfNodePtr> own_nodes_;
543 // the manager of symbolic shape's symbols and operations.
544 SymbolEnginePtr symbol_engine_;
545 // Reserve the func graph, not to release in manager.
546 bool reserved_{false};
547 };
548
NewCNode(const AnfNodePtrList & inputs,const FuncGraphPtr & fg)549 inline CNodePtr NewCNode(const AnfNodePtrList &inputs, const FuncGraphPtr &fg) {
550 MS_EXCEPTION_IF_NULL(fg);
551 return fg->NewCNode(inputs);
552 }
553
NewCNode(AnfNodePtrList && inputs,const FuncGraphPtr & fg)554 inline CNodePtr NewCNode(AnfNodePtrList &&inputs, const FuncGraphPtr &fg) {
555 MS_EXCEPTION_IF_NULL(fg);
556 return fg->NewCNode(std::move(inputs));
557 }
558
559 MS_CORE_API SeenNum NewFgSeenGeneration();
560
561 // Find the root cnodes of a segment of cnodes.
562 std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
563 // Find the leaf cnodes of a segment of cnodes.
564 std::shared_ptr<OrderedSet<CNodePtr>> FindLeaves(const std::vector<CNodePtr> &segment);
565 } // namespace mindspore
566
567 #endif // MINDSPORE_CORE_IR_FUNC_GRAPH_H_
568