• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/func_graph_cloner.h"
18 #include <algorithm>
19 #include <set>
20 
21 #include "abstract/abstract_function.h"
22 #include "ir/graph_utils.h"
23 #include "ir/manager.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "utils/convert_utils_base.h"
27 #include "utils/log_adapter.h"
28 #include "utils/ms_context.h"
29 #include "utils/parallel_node_check.h"
30 #include "utils/profile.h"
31 #include "utils/trace_base.h"
32 
33 // namespace to support intermediate representation definition
34 namespace mindspore {
35 namespace {
CloneNodeDebugInfo(const DebugInfoPtr & debug_info,const TraceInfoPtr & relation)36 NodeDebugInfoPtr CloneNodeDebugInfo(const DebugInfoPtr &debug_info, const TraceInfoPtr &relation) {
37   auto trace_info = relation->clone();
38   trace_info->set_debug_info(debug_info);
39   return std::make_shared<NodeDebugInfo>(std::move(trace_info));
40 }
41 
CloneNodeDebugInfo(const NodeDebugInfoPtr & debug_info)42 NodeDebugInfoPtr CloneNodeDebugInfo(const NodeDebugInfoPtr &debug_info) {
43   auto trace_info = std::make_shared<TraceCopy>(debug_info);
44   return std::make_shared<NodeDebugInfo>(std::move(trace_info));
45 }
46 
CloneGraphDebugInfo(const GraphDebugInfoPtr & debug_info,const TraceInfoPtr & relation)47 GraphDebugInfoPtr CloneGraphDebugInfo(const GraphDebugInfoPtr &debug_info, const TraceInfoPtr &relation) {
48   auto trace_info = relation->clone();
49   trace_info->set_debug_info(debug_info);
50   return std::make_shared<GraphDebugInfo>(std::move(trace_info));
51 }
52 }  // namespace
53 
Cloner(const FuncGraphVector & func_graphs,bool clone_all_valuenodes,bool clone_all_child_graphs,bool clone_all_used_graphs,const TraceInfoPtr & relation,const TraceInfoPtr & target_relation)54 Cloner::Cloner(const FuncGraphVector &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
55                bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
56     : clone_all_valuenodes_(clone_all_valuenodes),
57       clone_all_child_graphs_(clone_all_child_graphs),
58       clone_all_used_graphs_(clone_all_used_graphs),
59       relation_(relation),
60       target_relation_(target_relation == nullptr ? relation : target_relation),
61       scope_(kDefaultScope),
62       type_(kBasic) {
63   for (auto &func_graph : func_graphs) {
64     AddClone(func_graph);
65   }
66 }
67 
AddClone(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph,const AnfNodePtrList & params,CloneType type)68 void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
69                       const AnfNodePtrList &params, CloneType type) {
70   if (func_graph != nullptr) {
71     (void)todo_.emplace_back(CloneInfo{func_graph, target_func_graph, params});
72     type_ = type;
73   }
74 }
75 
CloneNode(const AnfNodePtr & node,const FuncGraphPtr & target)76 void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
77   MS_EXCEPTION_IF_NULL(node);
78   if (replicated_node_.find(node) != replicated_node_.end()) {
79     return;
80   }
81   if (node->isa<CNode>()) {
82     CloneCNodeWithoutInputs(node, target);
83   } else if (node->isa<Parameter>()) {
84     CloneParameter(node, target, false);
85   }
86 }
87 
CloneParameter(const AnfNodePtr & node,const FuncGraphPtr & target,bool is_add)88 void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
89   MS_EXCEPTION_IF_NULL(node);
90   MS_EXCEPTION_IF_NULL(target);
91   auto old_param = node->cast_ptr<Parameter>();
92   MS_EXCEPTION_IF_NULL(old_param);
93   auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
94   auto new_param = (is_add ? target->add_parameter(std::move(debug_info))
95                            : std::make_shared<Parameter>(target, std::move(debug_info)));
96   if (preset_abstract()) {
97     new_param->set_abstract(old_param->abstract());
98   }
99   new_param->set_name(old_param->name());
100   if (old_param->has_default()) {
101     // Default parameter can be shared since it is readonly.
102     new_param->set_default_param(old_param->default_param());
103   }
104   new_param->set_is_top_graph_param(old_param->is_top_graph_param());
105   ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
106   new_param->set_scope(scope);
107   replicated_node_[node] = std::move(new_param);
108 }
109 
110 // Create a new empty CNode for old one, and bind them.
111 // Also see LinkCNodeEdges().
CloneCNodeWithoutInputs(const AnfNodePtr & node,const FuncGraphPtr & target)112 void Cloner::CloneCNodeWithoutInputs(const AnfNodePtr &node, const FuncGraphPtr &target) {
113   MS_EXCEPTION_IF_NULL(node);
114   MS_EXCEPTION_IF_NULL(target);
115   auto old_node = node->cast<CNodePtr>();
116   AnfNodeWeakPtrList inputs;
117   inputs.reserve(old_node->size());
118   DebugInfoPtr debug_info;
119   if (this->update_info() != nullptr && this->update_info()->debug_info_ != nullptr) {
120     debug_info = this->update_info()->debug_info_;
121   } else {
122     debug_info = node->debug_info();
123   }
124 
125   auto cloned_debug_info = CloneNodeDebugInfo(debug_info, relation_);
126   CNodePtr new_node = std::make_shared<CNode>(std::move(inputs), target, std::move(cloned_debug_info));
127   if (inline_call_node_ != nullptr) {
128     MS_LOG(DEBUG) << "inline_call_node_: " << inline_call_node_ << "/" << inline_call_node_->DebugString()
129                   << ", new_node: " << new_node << "/" << new_node->DebugString();
130     UpdateInlineCNodeDebugInfo(inline_call_node_, new_node);
131   } else {
132     // Synchronize callers' shadow debug infos.
133     auto &new_shadow_debug_infos = new_node->debug_info()->shadow_debug_infos_map();
134     const auto &old_shadow_debug_infos = debug_info->shadow_debug_infos_map();
135     new_shadow_debug_infos.insert(old_shadow_debug_infos.cbegin(), old_shadow_debug_infos.cend());
136   }
137   new_node->CloneCNodeInfo(old_node);
138   // Copy to target graph
139   if (new_node->forward().first != nullptr) {
140     target->set_used_forward_nodes({new_node});
141   }
142   ScopePtr scope;
143   if (this->update_info() != nullptr && this->update_info()->scope_ != nullptr) {
144     scope = this->update_info()->scope_;
145   } else {
146     scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
147   }
148   new_node->set_scope(scope);
149   replicated_node_[node] = std::move(new_node);
150 }
151 
CloneValueNode(const AnfNodePtr & node)152 void Cloner::CloneValueNode(const AnfNodePtr &node) {
153   MS_EXCEPTION_IF_NULL(node);
154   auto value_node = node->cast_ptr<ValueNode>();
155   MS_EXCEPTION_IF_NULL(value_node);
156   auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
157   ValueNodePtr new_const = NewValueNode(GetValueNode(node), std::move(debug_info));
158   ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
159   new_const->set_scope(scope);
160   if (preset_abstract()) {
161     new_const->set_abstract(node->abstract());
162   }
163   new_const->set_has_new_value(value_node->has_new_value());
164   replicated_node_[node] = std::move(new_const);
165 }
166 
CloneFuncGraphValueNode(const AnfNodePtr & node,const FuncGraphPtr & target)167 void Cloner::CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
168   MS_EXCEPTION_IF_NULL(node);
169   MS_EXCEPTION_IF_NULL(target);
170   auto value_node = node->cast_ptr<ValueNode>();
171   MS_EXCEPTION_IF_NULL(value_node);
172   auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
173   ValueNodePtr new_const = NewValueNode(target, std::move(debug_info));
174   ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
175   new_const->set_scope(scope);
176   if (preset_abstract()) {
177     new_const->set_abstract(node->abstract());
178   }
179   new_const->set_has_new_value(value_node->has_new_value());
180   replicated_node_[node] = std::move(new_const);
181 }
182 
CloneValueNodes(const FuncGraphPtr & func_graph)183 void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
184   MS_EXCEPTION_IF_NULL(func_graph);
185   if (!clone_all_valuenodes_) {
186     return;
187   }
188   auto &value_nodes = func_graph->value_nodes();
189   for (auto &value_node : value_nodes) {
190     auto &old_node = value_node.first;
191     if (replicated_node_.find(old_node) == replicated_node_.end()) {
192       CloneValueNode(old_node);
193     }
194   }
195 }
196 
AddChildGraphs(const FuncGraphPtr & func_graph)197 void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
198   MS_EXCEPTION_IF_NULL(func_graph);
199   MS_EXCEPTION_IF_NULL(manager_);
200   if (!clone_all_child_graphs_) {
201     return;
202   }
203   // The graph marked 'no_child_graph' has no child graph.
204   if (func_graph->has_flag(FUNC_GRAPH_FLAG_NO_CHILD_GRAPH)) {
205     return;
206   }
207   auto &scopes = manager_->scopes(func_graph);
208   std::set<const FuncGraph *> memo;
209   for (auto &graph : scopes) {
210     // Avoid to insert duplicate function.
211     if (graph == func_graph || !memo.emplace(graph.get()).second) {
212       continue;
213     }
214     (void)todo_.emplace_back(CloneInfo{graph, nullptr, {}});
215   }
216 }
217 
AddTotalGraphs(const FuncGraphPtr & func_graph)218 void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
219   MS_EXCEPTION_IF_NULL(func_graph);
220   if (!clone_all_used_graphs_) {
221     return;
222   }
223   std::set<const FuncGraph *> memo;
224   auto &used = func_graph->func_graphs_used();
225   for (auto &fg : used) {
226     // Avoid to insert duplicate function.
227     if (!memo.emplace(fg.first.get()).second) {
228       continue;
229     }
230     (void)todo_.emplace_back(CloneInfo{fg.first, nullptr, {}});
231   }
232 }
233 
CloneFuncGraphDefaultValues(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)234 void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
235   MS_EXCEPTION_IF_NULL(func_graph);
236   MS_EXCEPTION_IF_NULL(target_func_graph);
237   for (auto &item : func_graph->parameter_default_value()) {
238     auto nodes = TopoSort(item.second, SuccDeeperSimple);
239     for (auto &node : nodes) {
240       MS_EXCEPTION_IF_NULL(node);
241       if (node->isa<CNode>()) {
242         CloneNode(node, target_func_graph);
243       } else if (node->isa<ValueNode>()) {
244         CloneValueNode(node);
245       }
246     }
247   }
248 }
249 
CloneFuncGraphValueNodes(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)250 void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
251   MS_EXCEPTION_IF_NULL(func_graph);
252   MS_EXCEPTION_IF_NULL(target_func_graph);
253 
254   target_func_graph->set_stage(func_graph->stage());
255   target_func_graph->set_segment(func_graph->segment());
256   auto &old_return = func_graph->return_node();
257   if (old_return != nullptr) {
258     auto iter = replicated_node_.find(old_return);
259     if (iter == replicated_node_.end()) {
260       MS_LOG(INTERNAL_EXCEPTION) << "Can't find replicate node for return.";
261     }
262     MS_EXCEPTION_IF_NULL(iter->second);
263     auto return_node = iter->second->cast<CNodePtr>();
264     MS_EXCEPTION_IF_NULL(return_node);
265     target_func_graph->set_return(return_node);
266   } else {
267     MS_LOG(ERROR) << "Has no return node, func_graph: " << func_graph << "/" << func_graph->ToString();
268   }
269 
270   auto &cnodes = func_graph->func_graph_cnodes_index();
271   for (auto &cnode : cnodes) {
272     MS_EXCEPTION_IF_NULL(cnode.first);
273     MS_EXCEPTION_IF_NULL(cnode.first->first);
274     auto user_cnode = cnode.first->first->cast_ptr<CNode>();
275     MS_EXCEPTION_IF_NULL(user_cnode);
276     const auto &valuenode = user_cnode->input(IntToSize(cnode.first->second));
277     if (valuenode == nullptr) {
278       continue;
279     }
280     CloneFuncGraphValueNode(valuenode, target_func_graph);
281   }
282 }
283 
InlineCloneParameters(const FuncGraphPtr & func_graph,const AnfNodePtrList & params)284 void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params) {
285   MS_EXCEPTION_IF_NULL(func_graph);
286   auto &old_params = func_graph->parameters();
287   if (old_params.size() != params.size()) {
288     MS_INTERNAL_EXCEPTION(TypeError) << "Origin params size[" << old_params.size() << "], inline params size["
289                                      << params.size() << "]";
290   }
291   for (size_t i = 0; i < old_params.size(); ++i) {
292     replicated_node_[old_params[i]] = params[i];
293   }
294 }
295 
SetFuncGraphInfo(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph) const296 void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) const {
297   MS_EXCEPTION_IF_NULL(func_graph);
298   MS_EXCEPTION_IF_NULL(target_func_graph);
299   target_func_graph->set_attrs(func_graph->attrs());
300   target_func_graph->set_transforms(func_graph->transforms());
301   target_func_graph->set_has_vararg(func_graph->has_vararg());
302   target_func_graph->set_has_kwarg(func_graph->has_kwarg());
303   target_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
304   target_func_graph->set_fv_param_count(func_graph->fv_param_count());
305   target_func_graph->set_is_generate(func_graph->is_generated());
306   target_func_graph->set_stub(func_graph->stub());
307   target_func_graph->set_indirect(func_graph->indirect());
308   target_func_graph->set_python_obj(func_graph->python_obj());
309   target_func_graph->set_has_side_effect_node(func_graph->has_side_effect_node());
310 }
311 
CloneParameters(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)312 void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
313   MS_EXCEPTION_IF_NULL(func_graph);
314   MS_EXCEPTION_IF_NULL(target_func_graph);
315   auto &params = func_graph->parameters();
316   for (auto &param : params) {
317     CloneParameter(param, target_func_graph, true);
318   }
319 }
320 
GenParameters(const FuncGraphPtr & func_graph)321 void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
322   MS_EXCEPTION_IF_NULL(func_graph);
323   auto &free_vars = manager_->free_variables_total();
324   auto iter = free_vars.find(func_graph);
325   if (iter == free_vars.end()) {
326     return;
327   }
328 
329   CloneInfo item = todo_.back();
330   auto lift_top_func_graph = item.origin;
331   for (auto &fv_map : iter->second) {
332     auto &free_var = fv_map.first;
333     if (!utils::isa<AnfNodePtr>(free_var)) {
334       continue;
335     }
336     auto free_var_node = utils::cast<AnfNodePtr>(free_var);
337     // Don't lift weight parameter to top func_graph.
338     if (IsLiftTopFuncGraph(func_graph) && free_var_node->isa<Parameter>()) {
339       auto free_var_param = free_var_node->cast_ptr<Parameter>();
340       if (free_var_param->has_default()) {
341         MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->DebugString()
342                       << " for top_func_graph: " << lift_top_func_graph->ToString();
343         continue;
344       }
345     }
346     auto &replicated_node = replicated_map_node_[func_graph];
347     if (replicated_node.find(free_var_node) != replicated_node.end()) {
348       MS_LOG(DEBUG) << "Param exists: " << free_var_node->DebugString()
349                     << " for func_graph: " << func_graph->ToString();
350       continue;
351     }
352 
353     MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
354     auto fv_parameter = AddParameter(func_graph, free_var_node);
355     fv_parameter->set_user_data<bool>("lifted_from_fv", std::make_shared<bool>(true));
356     auto &fg_params = replicated_func_graph_params_[func_graph];
357     (void)fg_params.emplace_back(fv_parameter);
358   }
359 }
360 
CloneParameter(const ParameterPtr & param,const AnfNodePtr & node) const361 void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) const {
362   MS_EXCEPTION_IF_NULL(param);
363   MS_EXCEPTION_IF_NULL(node);
364   if (preset_abstract()) {
365     param->set_abstract(node->abstract());
366   }
367   if (node->isa<Parameter>()) {
368     auto old_param = node->cast_ptr<Parameter>();
369     if (old_param->has_default()) {
370       // Default parameter can be shared since it is readonly.
371       param->set_default_param(old_param->default_param());
372     }
373     param->set_name(old_param->name());
374     constexpr char lifted_user_data_key[] = "lifted_from_fv";
375     auto lifted = param->user_data<bool>(lifted_user_data_key);
376     if (lifted != nullptr && *lifted) {
377       param->set_user_data<bool>(lifted_user_data_key, std::make_shared<bool>(true));
378     }
379   }
380 }
381 
AddParameter(const FuncGraphPtr & func_graph,const AnfNodePtr & node,bool is_add)382 ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
383   MS_EXCEPTION_IF_NULL(func_graph);
384   MS_EXCEPTION_IF_NULL(node);
385   auto debug_info = CloneNodeDebugInfo(node->debug_info());
386   ParameterPtr param = std::make_shared<Parameter>(func_graph, std::move(debug_info));
387   CloneParameter(param, node);
388   if (is_add) {
389     func_graph->add_parameter(param);
390   }
391   replicated_node_[param] = node;
392   replicated_map_node_[func_graph][node] = param;
393   return param;
394 }
395 
396 namespace {
FilterMonadInput(const AnfNodeWeakPtrList & old_inputs,AnfNodeWeakPtrList * new_inputs,AnfNodePtr * possible_u_monad,AnfNodePtr * possible_io_monad)397 bool FilterMonadInput(const AnfNodeWeakPtrList &old_inputs, AnfNodeWeakPtrList *new_inputs,
398                       AnfNodePtr *possible_u_monad, AnfNodePtr *possible_io_monad) {
399   AnfNodePtr local_u_monad = nullptr;
400   AnfNodePtr local_io_monad = nullptr;
401   for (const auto &weak_input : old_inputs) {
402     auto input = weak_input.lock();
403     MS_EXCEPTION_IF_NULL(input);
404     // Should be only one U Monad input.
405     if (HasAbstractUMonad(input)) {
406       if (local_u_monad != nullptr) {
407         MS_LOG(ERROR) << "Cannot have multiple U Monad in one call, first: " << local_u_monad->ToString()
408                       << ", second: " << input->ToString();
409         return false;
410       }
411       local_u_monad = input;
412       continue;
413     }
414     // Should be only one IO Monad input.
415     if (HasAbstractIOMonad(input)) {
416       if (local_io_monad != nullptr) {
417         MS_LOG(ERROR) << "Cannot have multiple IO Monad in one call, first: " << local_io_monad->ToString()
418                       << ", second: " << input->ToString();
419         return false;
420       }
421       local_io_monad = input;
422       continue;
423     }
424     // Collect all non-monad inputs.
425     (void)new_inputs->emplace_back(weak_input);
426   }
427   *possible_u_monad = local_u_monad;
428   *possible_io_monad = local_io_monad;
429   return true;
430 }
431 
432 // After lift, func_graph will not refer any free variable, so DummyContext is proper.
BuildFuncGraphValueNode(const FuncGraphPtr & func_graph,bool preset_abstract)433 AnfNodePtr BuildFuncGraphValueNode(const FuncGraphPtr &func_graph, bool preset_abstract) {
434   auto new_node = NewValueNode(func_graph);
435   auto abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
436     func_graph, abstract::AnalysisContext::DummyContext(), new_node, preset_abstract);
437   new_node->set_abstract(abstract);
438   return new_node;
439 }
440 
BuildPrimitiveValueNode(const PrimitivePtr & primitive)441 AnfNodePtr BuildPrimitiveValueNode(const PrimitivePtr &primitive) {
442   auto new_node = NewValueNode(primitive);
443   auto abstract = std::make_shared<abstract::PrimitiveAbstractClosure>(primitive, new_node);
444   new_node->set_abstract(abstract);
445   return new_node;
446 }
447 
PresetPartialAbstractClosure(const CNodePtr & cnode,const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & weak_inputs,bool preset_abstract)448 void PresetPartialAbstractClosure(const CNodePtr &cnode, const FuncGraphPtr &func_graph,
449                                   const AnfNodeWeakPtrList &weak_inputs, bool preset_abstract) {
450   if (!preset_abstract) {
451     return;
452   }
453   constexpr auto ignore_partial_fg_count = 2;
454   AbstractBasePtrList args_abs_list;
455   (void)std::for_each(weak_inputs.cbegin() + ignore_partial_fg_count, weak_inputs.cend(),
456                       [&args_abs_list](const AnfNodeWeakPtr &weak_node) {
457                         auto node = weak_node.lock();
458                         MS_EXCEPTION_IF_NULL(node);
459                         (void)args_abs_list.emplace_back(node->abstract());
460                       });
461   MS_EXCEPTION_IF_NULL(func_graph->ToAbstract());
462   auto abs = std::make_shared<abstract::PartialAbstractClosure>(
463     func_graph->ToAbstract()->cast<abstract::AbstractFuncAtomPtr>(), args_abs_list, cnode);
464   cnode->set_abstract(abs);
465 }
466 }  // namespace
467 
IsLiftTopFuncGraph(const FuncGraphPtr & func_graph)468 bool Cloner::IsLiftTopFuncGraph(const FuncGraphPtr &func_graph) {
469   const auto &iter = std::find_if(todo_.begin(), todo_.end(),
470                                   [func_graph](const CloneInfo &item) -> bool { return item.origin == func_graph; });
471   if (iter == todo_.end()) {
472     return false;
473   }
474   return true;
475 }
476 
OrderParameters(const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & inputs,size_t arg_start_index)477 void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList &inputs, size_t arg_start_index) {
478   MS_EXCEPTION_IF_NULL(func_graph);
479   mindspore::HashSet<AnfNodePtr> old_params;
480   for (auto &param : func_graph->parameters()) {
481     (void)old_params.insert(replicated_node_[param]);
482   }
483   mindspore::HashSet<AnfNodePtr> new_params;
484   AnfNodePtrList parameters;
485   // Ignore the 1st and 2nd param of inputs(such as. partial graph)
486   for (size_t i = arg_start_index; i < inputs.size(); ++i) {
487     const auto &input = inputs[i].lock();
488     MS_EXCEPTION_IF_NULL(input);
489     const auto &param = replicated_node_[input];
490     if (old_params.find(param) != old_params.end()) {
491       auto &new_param = replicated_map_node_[func_graph][param];
492       (void)parameters.emplace_back(new_param);
493       (void)new_params.insert(new_param);
494     }
495   }
496   for (auto &param : func_graph->parameters()) {
497     if (new_params.find(param) == new_params.end()) {
498       (void)parameters.emplace_back(param);
499     }
500   }
501   func_graph->set_parameters(std::move(parameters));
502 }
503 
504 // Avoid to create nested partial CNode.
SetPartialEdges(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FuncGraphTransaction * tx)505 CNodePtr Cloner::SetPartialEdges(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FuncGraphTransaction *tx) {
506   if (!IsPrimitiveCNode(cnode, prim::kPrimPartial) || !IsValueNode<FuncGraph>(cnode->input(1))) {
507     return nullptr;
508   }
509   auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
510   MS_EXCEPTION_IF_NULL(graph);
511   auto &replicated_func_graph = replicated_map_func_graph_[func_graph];
512   if (replicated_func_graph.find(graph) == replicated_func_graph.end()) {
513     return nullptr;
514   }
515 
516   auto partial_node = replicated_func_graph[graph];
517   if (!IsPrimitiveCNode(partial_node, prim::kPrimPartial)) {
518     return nullptr;
519   }
520   auto partial_cnode = dyn_cast<CNode>(partial_node);
521   MS_EXCEPTION_IF_NULL(partial_cnode);
522   auto value_node = BuildPrimitiveValueNode(prim::kPrimPartial);
523   MS_EXCEPTION_IF_NULL(value_node);
524   auto func_graph_node = BuildFuncGraphValueNode(graph, preset_abstract());
525   MS_EXCEPTION_IF_NULL(func_graph_node);
526   AnfNodeWeakPtrList new_inputs = {value_node, func_graph_node};
527   constexpr auto ignore_partial_fg_count = 2;
528   (void)std::copy(partial_cnode->weak_inputs().cbegin() + ignore_partial_fg_count, partial_cnode->weak_inputs().cend(),
529                   std::back_inserter(new_inputs));
530   (void)std::copy(cnode->weak_inputs().cbegin() + ignore_partial_fg_count, cnode->weak_inputs().cend(),
531                   std::back_inserter(new_inputs));
532   auto new_cnode = func_graph->NewCNodeWeak(std::move(new_inputs));
533   MS_EXCEPTION_IF_NULL(new_cnode);
534   PresetPartialAbstractClosure(new_cnode, graph, new_cnode->weak_inputs(), preset_abstract());
535 
536   MS_LOG(DEBUG) << "Rebuild partial CNode, old_node: " << cnode->DebugString()
537                 << ", partial_cnode: " << partial_cnode->DebugString() << ", new_node: " << new_cnode->DebugString()
538                 << ", new_node abs: " << (new_cnode->abstract() != nullptr ? new_cnode->abstract()->ToString() : "null")
539                 << ", partial " << graph->ToString() << " in " << func_graph->ToString();
540   (void)tx->Replace(cnode, new_cnode);
541   return new_cnode;
542 }
543 
SetEdges(const FuncGraphPtr & func_graph,FuncGraphTransaction * tx)544 void Cloner::SetEdges(const FuncGraphPtr &func_graph, FuncGraphTransaction *tx) {
545   MS_EXCEPTION_IF_NULL(func_graph);
546   MS_EXCEPTION_IF_NULL(tx);
547   for (auto &node : func_graph->nodes()) {
548     auto cnode = dyn_cast<CNode>(node);
549     // Only cnode needed to be handled
550     if (cnode == nullptr) {
551       continue;
552     }
553 
554     // Avoid to create nested partial CNode.
555     auto old_cnode = cnode;
556     auto new_cnode = SetPartialEdges(func_graph, cnode, tx);
557     if (new_cnode != nullptr) {
558       cnode = new_cnode;
559     }
560 
561     const auto &inputs = cnode->inputs();
562     for (size_t i = 0; i < inputs.size(); ++i) {
563       auto &input = inputs[i];
564       if (IsValueNode<FuncGraph>(input)) {
565         if (i == 1 && new_cnode != nullptr) {
566           continue;
567         }
568         auto graph = GetValueNode<FuncGraphPtr>(input);
569         auto &replicated_func_graph = replicated_map_func_graph_[func_graph];
570         if (replicated_func_graph.find(graph) != replicated_func_graph.end()) {
571           auto partial_node = replicated_func_graph[graph];
572           tx->SetEdge(cnode, static_cast<int>(i), partial_node);
573         }
574       } else {
575         auto &replicated_node = replicated_map_node_[func_graph];
576         if (replicated_node.find(input) != replicated_node.end()) {
577           tx->SetEdge(cnode, static_cast<int>(i), replicated_node[input]);
578         }
579       }
580     }
581   }
582 }
583 
// Rebuild `func_graph`'s parameter list for lifting and compute the values its call site must pass.
//
// Resulting parameter list:
//   1. Each existing parameter: kept as-is when it already has an entry in replicated_node_; otherwise replaced
//      by a fresh clone created via AddParameter(..., false).
//   2. One newly lifted parameter for every entry of `params` that is not resolved by one of the cases below.
//
// params:       nodes to be supplied to func_graph; each is mapped through replicated_node_ to its origin node.
// lift_params:  out; receives only the newly lifted parameters appended to func_graph.
// input_params: out; receives, for every entry of `params` in order, the node to pass at the call site.
void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList &params,
                           AnfNodeWeakPtrList *const lift_params, AnfNodeWeakPtrList *const input_params) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(lift_params);
  MS_EXCEPTION_IF_NULL(input_params);
  AnfNodePtrList parameters;
  // Origin nodes of func_graph's parameters: the replicated_node_ value for parameters that have one, or the
  // parameter itself when it is re-cloned here (AddParameter maps the clone back to it).
  mindspore::HashSet<AnfNodePtr> old_params;
  for (auto &param : func_graph->parameters()) {
    auto iter = replicated_node_.find(param);
    if (iter != replicated_node_.end()) {
      (void)old_params.insert(iter->second);
      (void)parameters.emplace_back(param);
    } else {
      // is_add == false: the clone is only collected here; set_parameters() installs the full list at the end.
      (void)parameters.emplace_back(AddParameter(func_graph, param, false));
      (void)old_params.insert(param);
    }
  }
  AnfNodePtr new_param = nullptr;
  for (auto &weak_param : params) {
    const auto &param = weak_param.lock();
    // operator[] inserts a null entry when `param` was never replicated; the null check below then raises.
    auto old_param = replicated_node_[param];
    MS_EXCEPTION_IF_NULL(old_param);
    // The origin is a CNode that already belongs to func_graph: map it to itself and pass it through directly.
    if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
      replicated_node_[old_param] = old_param;
      replicated_map_node_[func_graph][old_param] = old_param;
      (void)input_params->emplace_back(old_param);
      continue;
    }
    // The origin corresponds to one of func_graph's existing parameters: pass that parameter.
    if (old_params.find(old_param) != old_params.end()) {
      new_param = replicated_map_node_[func_graph][old_param];
      if (new_param == nullptr) {
        MS_LOG(INTERNAL_EXCEPTION) << "map_node, func_graph: " << func_graph->ToString()
                                   << ", old_param: " << old_param->DebugString() << " cannot found";
      }
      (void)input_params->emplace_back(new_param);
      continue;
    }
    if (IsLiftTopFuncGraph(func_graph)) {
      // Don't lift parameter from used_graphs to my parameter if I am the top;
      // the origin is passed through unchanged and mapped to itself, both here and in its own graph.
      replicated_node_[old_param] = old_param;
      replicated_map_node_[func_graph][old_param] = old_param;
      MS_EXCEPTION_IF_NULL(old_param->func_graph());
      replicated_map_node_[old_param->func_graph()][old_param] = old_param;
      (void)input_params->emplace_back(old_param);
      MS_LOG(DEBUG) << "Bypass " << old_param->DebugString() << " for top func_graph: " << func_graph->ToString();
      continue;
    }
    // Otherwise lift: create a new parameter for the origin, append it, and pass it at the call site.
    new_param = AddParameter(func_graph, old_param, false);
    (void)parameters.emplace_back(new_param);
    (void)lift_params->emplace_back(new_param);
    (void)input_params->emplace_back(new_param);
  }
  func_graph->set_parameters(std::move(parameters));
}
638 
AddInputs(const FuncGraphPtr & func_graph_user,const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & params)639 void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
640                        const AnfNodeWeakPtrList &params) {
641   auto &replicated_func_graph = replicated_map_func_graph_[func_graph_user];
642   auto [iter, inserted] = replicated_func_graph.emplace(func_graph, nullptr);
643   if (inserted) {
644     const auto value_node = BuildPrimitiveValueNode(prim::kPrimPartial);
645     const auto fg_value = BuildFuncGraphValueNode(func_graph, preset_abstract());
646     AnfNodeWeakPtrList cnode_inputs{value_node, fg_value};
647     auto partial_node = func_graph_user->NewCNodeWeak(std::move(cnode_inputs));
648     iter->second = partial_node;
649   }
650   auto cnode = dyn_cast<CNode>(iter->second);
651   if (cnode == nullptr) {
652     return;
653   }
654   AnfNodePtr input_u_monad;
655   AnfNodePtr input_io_monad;
656   AnfNodePtr param_u_monad;
657   AnfNodePtr param_io_monad;
658   AnfNodeWeakPtrList inputs;
659   AnfNodeWeakPtrList add_params;
660   if (!FilterMonadInput(cnode->weak_inputs(), &inputs, &input_u_monad, &input_io_monad)) {
661     constexpr auto recursive_level = 2;
662     MS_LOG(INTERNAL_EXCEPTION) << "Cannot have multiple U Monad or multiple IO Monad in one CNode, cnode: "
663                                << cnode->DebugString(recursive_level);
664   }
665   if (!FilterMonadInput(params, &add_params, &param_u_monad, &param_io_monad)) {
666     MS_LOG(INTERNAL_EXCEPTION) << "Cannot have multiple U Monad or multiple IO Monad in Parameters list, func_graph: "
667                                << func_graph->ToString();
668   }
669 
670   // Append new inputs from free variable.
671   constexpr auto caller_first_arg_index = 2;
672   for (size_t i = caller_first_arg_index; i < inputs.size(); i++) {
673     auto input = inputs[i].lock();
674     auto pos = std::find_if(add_params.cbegin(), add_params.cend(), [&input](const auto &weak_param) {
675       if (weak_param.lock() != nullptr && weak_param.lock() == input) {
676         return true;
677       }
678       return false;
679     });
680     if (pos != add_params.end()) {
681       (void)add_params.erase(pos);
682     }
683   }
684   (void)inputs.insert(inputs.end(), add_params.cbegin(), add_params.cend());
685 
686   // Append monad inputs.
687   if (input_u_monad != nullptr && param_u_monad != nullptr && input_u_monad != param_u_monad) {
688     MS_LOG(INTERNAL_EXCEPTION) << "Cannot have multiple U Monad in one call, first: " << input_u_monad->ToString()
689                                << ", second: " << param_u_monad->ToString();
690   }
691   if (input_io_monad != nullptr && param_io_monad != nullptr && input_io_monad != param_io_monad) {
692     MS_LOG(INTERNAL_EXCEPTION) << "Cannot have multiple IO Monad in one call, first: " << input_io_monad->ToString()
693                                << ", second: " << param_io_monad->ToString();
694   }
695   auto &u_monad = (input_u_monad != nullptr ? input_u_monad : param_u_monad);
696   auto &io_monad = (input_io_monad != nullptr ? input_io_monad : param_io_monad);
697   if (u_monad != nullptr) {
698     (void)inputs.emplace_back(u_monad);
699   }
700   if (io_monad != nullptr) {
701     (void)inputs.emplace_back(io_monad);
702   }
703 
704   cnode->set_weak_inputs(inputs);
705   OrderParameters(func_graph, inputs, caller_first_arg_index);
706   PresetPartialAbstractClosure(cnode, func_graph, inputs, preset_abstract());
707   MS_LOG(DEBUG) << "Create new partial CNode: " << cnode->DebugString();
708 }
709 
LiftParameters(const FuncGraphPtr & func_graph_user,const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & params)710 void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
711                             const AnfNodeWeakPtrList &params) {
712   MS_EXCEPTION_IF_NULL(func_graph_user);
713   AnfNodeWeakPtrList lift_params;
714   AnfNodeWeakPtrList input_params;
715   AddParameters(func_graph_user, params, &lift_params, &input_params);
716   AddInputs(func_graph_user, func_graph, input_params);
717   if (lift_params.empty()) {
718     return;
719   }
720   for (auto &cnode_index : func_graph_user->func_graph_cnodes_index()) {
721     MS_EXCEPTION_IF_NULL(cnode_index.first);
722     const auto &user_node = cnode_index.first->first;
723     MS_EXCEPTION_IF_NULL(user_node);
724     LiftParameters(user_node->func_graph(), func_graph_user, lift_params);
725   }
726 }
727 
Lift(const std::vector<FuncGraphPtr> & sorted)728 void Cloner::Lift(const std::vector<FuncGraphPtr> &sorted) {
729   // lift inner graph first
730   for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
731     auto func_graph = *r_iter;
732     auto iter = replicated_func_graph_params_.find(func_graph);
733     if (iter != replicated_func_graph_params_.end()) {
734       auto &params = iter->second;
735       for (auto &cnode_index : func_graph->func_graph_cnodes_index()) {
736         MS_EXCEPTION_IF_NULL(cnode_index.first);
737         const auto &user_node = cnode_index.first->first;
738         MS_EXCEPTION_IF_NULL(user_node);
739         LiftParameters(user_node->func_graph(), func_graph, params);
740       }
741     }
742   }
743 }
744 
SetEdgesBfs(const FuncGraphPtr & root_fg,FuncGraphTransaction * tx)745 void Cloner::SetEdgesBfs(const FuncGraphPtr &root_fg, FuncGraphTransaction *tx) {
746   MS_EXCEPTION_IF_NULL(root_fg);
747   const auto &func_graphs = BroadFirstSearchGraphUsed(root_fg, lifting_func_graph_filter());
748   for (auto &func_graph : func_graphs) {
749     SetEdges(func_graph, tx);
750   }
751 }
752 
LiftParameters(const FuncGraphVector & todo_func_graphs)753 void Cloner::LiftParameters(const FuncGraphVector &todo_func_graphs) {
754   MS_EXCEPTION_IF_NULL(manager_);
755   auto tx = manager_->Transact();
756   for (const auto &todo_func_graph : todo_func_graphs) {
757     const auto &func_graphs = BroadFirstSearchGraphUsed(todo_func_graph, lifting_func_graph_filter());
758     for (auto &func_graph : func_graphs) {
759       GenParameters(func_graph);
760     }
761     Lift(func_graphs);
762   }
763   const auto &roots = manager_->roots();
764   // Roots in manager is not set in Pynative mode.
765   if (roots.empty()) {
766     for (const auto &todo_func_graph : todo_func_graphs) {
767       SetEdgesBfs(todo_func_graph, &tx);
768     }
769   } else {
770     for (const auto &root_func_graph : roots) {
771       SetEdgesBfs(root_func_graph, &tx);
772     }
773   }
774   tx.Commit();
775 }
776 
CheckStatus(const FuncGraphPtr & func_graph,bool is_inline)777 bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
778   MS_EXCEPTION_IF_NULL(func_graph);
779   // Make sure only inline once
780   auto iter = status_.find(func_graph);
781   if (iter != status_.end()) {
782     if (is_inline == iter->second) {
783       return false;
784     }
785     if (clone_all_used_graphs_) {
786       MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False.";
787       return false;
788     }
789   }
790   return true;
791 }
792 
CloneAllNodes(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)793 void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
794   MS_EXCEPTION_IF_NULL(func_graph);
795   MS_EXCEPTION_IF_NULL(target_func_graph);
796   const AnfNodeSet &nodes = func_graph->nodes();
797   replicated_node_.reserve(replicated_node_.size() + nodes.size());
798   for (auto &node : nodes) {
799     CloneNode(node, target_func_graph);
800   }
801   // Only func_graph is inlined, it cannot be found in repl;
802   if (replicated_func_graph_.find(func_graph) != replicated_func_graph_.end()) {
803     CloneOrderList(func_graph, target_func_graph);
804   }
805 }
806 
CloneOrderList(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)807 void Cloner::CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
808   for (auto &weak_cnode : func_graph->order_list()) {
809     const auto &cnode = weak_cnode.lock();
810     if (cnode == nullptr) {
811       continue;
812     }
813     auto it = replicated_node_.find(cnode);
814     if (it == replicated_node_.end()) {
815       // For cnode which generated in Analyze phase, it cannot got from nodes API of func_graph,
816       // so it cannot be cloned in normal Clone API.
817       // If we ignore it, the order will be lost.
818       // Therefore we put this old node as placeholder to the order list of target func_graph to
819       // keep the order.
820       // It may be replaced in ProgramSpecialize.
821       // If this disconnected node is not used in target func_graph, it will be cleared after
822       // ProgramSpecialize;
823       target_func_graph->AppendOrderList(cnode);
824       continue;
825     }
826     auto replicated_cnode = dyn_cast<CNode>(it->second);
827     if (replicated_cnode != nullptr) {
828       target_func_graph->AppendOrderList(replicated_cnode);
829     }
830   }
831 }
832 
Run()833 void Cloner::Run() {
834   if (todo_.empty()) {
835     return;
836   }
837 
838   FuncGraphVector func_graphs;
839   (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
840                        [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
841   if (type_ < kLifting) {
842     // Basic and Inline Clone
843     manager_ = Manage(func_graphs, false);
844     CloneNodes();
845     LinkCNodeEdges();
846     SetDefaults();
847   } else {
848     // Lifting Clone
849     manager_ = Manage(func_graphs);
850     LiftParameters(func_graphs);
851   }
852 }
853 
CloneNodes()854 void Cloner::CloneNodes() {
855   while (!todo_.empty()) {
856     CloneInfo item = std::move(todo_.back());
857     todo_.pop_back();
858 
859     const bool is_inline = (item.target != nullptr);
860     FuncGraphPtr &func_graph = item.origin;
861     (void)graph_set_.insert(func_graph);
862 
863     if (!CheckStatus(func_graph, is_inline)) {
864       continue;
865     }
866 
867     if (is_inline) {
868       InlineCloneParameters(func_graph, item.params);
869       CloneAllNodes(func_graph, item.target);
870     } else {
871       auto debug_info = CloneGraphDebugInfo(func_graph->debug_info(), target_relation_);
872       auto target_func_graph = std::make_shared<FuncGraph>(std::move(debug_info));
873       SetFuncGraphInfo(func_graph, target_func_graph);
874       CloneParameters(func_graph, target_func_graph);
875       replicated_func_graph_[func_graph] = target_func_graph;
876       CloneAllNodes(func_graph, target_func_graph);
877       CloneFuncGraphValueNodes(func_graph, target_func_graph);
878       CloneFuncGraphDefaultValues(func_graph, target_func_graph);
879     }
880 
881     CloneValueNodes(func_graph);
882     AddChildGraphs(func_graph);
883     AddTotalGraphs(func_graph);
884     status_[func_graph] = is_inline;
885   }
886 }
887 
888 // Link the CNode with its inputs.
889 // Also see CloneCNodeWithoutInputs()
LinkCNodeEdges()890 void Cloner::LinkCNodeEdges() {
891   for (auto &repl : replicated_node_) {
892     auto old_node = dyn_cast_ptr<CNode>(repl.first);
893     if (old_node == nullptr) {
894       continue;
895     }
896     MS_EXCEPTION_IF_NULL(repl.second);
897     auto new_node = repl.second->cast_ptr<CNode>();
898     MS_EXCEPTION_IF_NULL(new_node);
899     for (auto &weak_input : old_node->weak_inputs()) {
900       auto input = weak_input.lock();
901       MS_EXCEPTION_IF_NULL(input);
902       auto iter = replicated_node_.find(input);
903       auto &new_input = (iter == replicated_node_.end() ? input : iter->second);
904       new_node->add_input(new_input);
905     }
906   }
907 }
908 
909 // For the graphs cloned, update its default value map to the cloned nodes.
SetDefaults()910 void Cloner::SetDefaults() {
911   for (auto &old_fg : graph_set_) {
912     MS_EXCEPTION_IF_NULL(old_fg);
913     auto iter = replicated_func_graph_.find(old_fg);
914     if (iter == replicated_func_graph_.end()) {
915       continue;
916     }
917     auto &new_fg = iter->second;
918     MS_EXCEPTION_IF_NULL(new_fg);
919     for (auto &param_def : old_fg->parameter_default_value()) {
920       auto replicated_iter = replicated_node_.find(param_def.second);
921       auto &value_node = (replicated_iter == replicated_node_.end() ? param_def.second : replicated_iter->second);
922       new_fg->set_param_default_value(param_def.first, value_node);
923     }
924   }
925 }
926 
// Clone the single node `root` into the replica of its owning graph and return the clone.
// Raises an internal exception if the owning graph has no replica in this cloner, or if no replica of
// `root` exists after cloning.
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto fg_iter = replicated_func_graph_.find(root->func_graph());
  if (fg_iter == replicated_func_graph_.end()) {
    MS_EXCEPTION_IF_NULL(root->func_graph());
    MS_LOG(INTERNAL_EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
  }
  CloneNode(root, fg_iter->second);

  auto node_iter = replicated_node_.find(root);
  if (node_iter == replicated_node_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
  }
  return node_iter->second;
}
941 
// Run any pending clone work (profiled), then return the clone of `node`, or `node` itself if it was not
// cloned.
AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
  {
    MsProfileStatGuard stat_guard("func_graph_cloner_run.FuncGraphClonerNode");
    Run();
  }

  auto replica = replicated_node_.find(node);
  if (replica != replicated_node_.end()) {
    return replica->second;
  }
  return node;
}
951 
// Run any pending clone work (profiled), then return the clone of `func_graph`, or `func_graph` itself
// if it was not cloned. The Python object of `func_graph` is copied onto the returned graph.
FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  {
    MsProfileStatGuard stat_guard("func_graph_cloner_run.FuncGraphClonerGraph");
    Run();
  }

  FuncGraphPtr result = func_graph;
  auto replica = replicated_func_graph_.find(func_graph);
  if (replica != replicated_func_graph_.end()) {
    result = replica->second;
  }
  result->set_python_obj(func_graph->python_obj());
  return result;
}
964 
// Clone `func_graph` into a new graph and return it.
// `clone_value_nodes` is forwarded to the Cloner, and `update_info` is applied only when non-null.
// The while-header flag is carried over to the clone.
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes, const UpdateInfoPtr update_info) {
  MS_EXCEPTION_IF_NULL(func_graph);
  Cloner cloner({func_graph}, clone_value_nodes, true, true);
  if (update_info != nullptr) {
    cloner.set_update_info(update_info);
  }
  auto target_func_graph = cloner[func_graph];
  if (!func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
    return target_func_graph;
  }
  MS_EXCEPTION_IF_NULL(target_func_graph);
  target_func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  return target_func_graph;
}
978 
// Inline `func_graph` into `target_func_graph`, binding its parameters to `func_graph_args`, and return
// the node in `target_func_graph` that corresponds to `func_graph`'s output.
// When `call_node` is non-null it must be a CNode. If its callee (input 0) has a non-null scope, that
// scope is passed to Cloner::set_scope. The call node itself is handed to the cloner via
// set_inline_call_node.
// The while-header and training flags of `func_graph` are propagated to `target_func_graph`.
AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
                       const AnfNodePtrList &func_graph_args, const AnfNodePtr &call_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(target_func_graph);
  Cloner cloner({}, false);
  if (call_node != nullptr) {
    auto call_cnode = dyn_cast<CNode>(call_node);
    MS_EXCEPTION_IF_NULL(call_cnode);
    // Fetch the callee once and check it before dereferencing; previously a null input 0 would crash.
    const auto callee = call_cnode->input(0);
    MS_EXCEPTION_IF_NULL(callee);
    const auto callee_scope = callee->scope();
    if (callee_scope != nullptr) {
      cloner.set_scope(callee_scope);
    }
  }
  cloner.set_inline_call_node(call_node);
  cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
  if (func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
    target_func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  }
  if (func_graph->has_flag(kTraining)) {
    target_func_graph->set_flag(kTraining, true);
  }
  return cloner[func_graph->output()];
}
1001 
// Lifting-clone `func_graph` and return the result.
// `preset_abstract` and `lifting_func_graph_filter` are forwarded to the cloner. The while-header flag is
// copied onto the result.
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph, bool preset_abstract,
                          const GraphFilterFunc &lifting_func_graph_filter) {
  MS_EXCEPTION_IF_NULL(func_graph);
  Cloner cloner({}, false);
  cloner.set_preset_abstract(preset_abstract);
  cloner.set_lifting_func_graph_filter(lifting_func_graph_filter);
  cloner.AddClone(func_graph, nullptr, {}, kLifting);
  auto lifted = cloner[func_graph];
  const bool is_while_header = func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER);
  if (is_while_header) {
    lifted->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  }
  return lifted;
}
1015 
// Lifting-clone several graphs using a single Cloner.
// Returns, in input order, the graph registered for each input in cloned_func_graphs(), or the input
// itself when none is registered. Each input's Python object is copied onto its result.
FuncGraphVector LiftingCloneMulti(const FuncGraphVector &func_graphs) {
  Cloner cloner({}, false);
  for (const auto &func_graph : func_graphs) {
    cloner.AddClone(func_graph, nullptr, {}, kLifting);
  }
  cloner.Run();

  const auto &replicated_func_graphs = cloner.cloned_func_graphs();
  FuncGraphVector lifted_func_graphs;
  for (const auto &func_graph : func_graphs) {
    FuncGraphPtr ret = func_graph;
    if (auto iter = replicated_func_graphs.find(func_graph); iter != replicated_func_graphs.end()) {
      ret = iter->second;
    }
    MS_EXCEPTION_IF_NULL(ret);
    ret->set_python_obj(func_graph->python_obj());
    (void)lifted_func_graphs.emplace_back(std::move(ret));
  }
  return lifted_func_graphs;
}
1035 
SpecializerClone(const FuncGraphPtr & func_graph,const TraceInfoPtr & relation)1036 ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
1037   MS_EXCEPTION_IF_NULL(func_graph);
1038   FuncGraphVector func_graphs = {func_graph};
1039   ClonerPtr cloner =
1040     std::make_shared<Cloner>(func_graphs, false, false, false, std::make_shared<TraceCopy>(), relation);
1041   {
1042     MsProfileStatGuard stat_guard("func_graph_cloner_run.FuncGraphSpecializer");
1043     cloner->Run();
1044   }
1045   return cloner;
1046 }
1047 
// Create a standalone copy of `func_graph` whose debug info is derived with `relation`.
// The copy gets:
// - one new parameter per source parameter, carrying the source abstract;
// - a body inlined onto those parameters;
// - cloned default values;
// - the graph-level properties, flags and attributes listed below, copied from the source.
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto new_func_graph = std::make_shared<FuncGraph>(CloneGraphDebugInfo(func_graph->debug_info(), relation));

  // Fresh parameters mirroring the source ones.
  for (const auto &param : func_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(param);
    auto new_param = new_func_graph->add_parameter(CloneNodeDebugInfo(param->debug_info()));
    new_param->set_abstract(param->abstract());
  }

  // Clone the body into new_func_graph (a non-null target means inline cloning), bound to the new params.
  Cloner cloner({}, true);
  cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters());
  new_func_graph->set_output(cloner[func_graph->output()]);

  // Scalar graph properties.
  new_func_graph->set_has_vararg(func_graph->has_vararg());
  new_func_graph->set_has_kwarg(func_graph->has_kwarg());
  new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  new_func_graph->set_fv_param_count(func_graph->fv_param_count());
  new_func_graph->set_is_generate(func_graph->is_generated());
  new_func_graph->set_indirect(func_graph->indirect());
  new_func_graph->set_stub(func_graph->stub());

  // Default values point at cloned nodes where the cloner produced one, otherwise at the original node.
  for (const auto &item : func_graph->parameter_default_value()) {
    new_func_graph->set_param_default_value(item.first, cloner[item.second]);
  }

  // Selected flags and attributes.
  if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
    new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
  }
  if (func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
    new_func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  }
  if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
    new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  }
  new_func_graph->set_stage(func_graph->stage());
  new_func_graph->set_segment(func_graph->segment());
  return new_func_graph;
}
1086 }  // namespace mindspore
1087