1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/func_graph_cloner.h"
18 #include <algorithm>
19 #include <set>
20
21 #include "abstract/abstract_function.h"
22 #include "ir/graph_utils.h"
23 #include "ir/manager.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "utils/convert_utils_base.h"
27 #include "utils/log_adapter.h"
28 #include "utils/ms_context.h"
29 #include "utils/parallel_node_check.h"
30 #include "utils/profile.h"
31 #include "utils/trace_base.h"
32
33 // namespace to support intermediate representation definition
34 namespace mindspore {
35 namespace {
CloneNodeDebugInfo(const DebugInfoPtr & debug_info,const TraceInfoPtr & relation)36 NodeDebugInfoPtr CloneNodeDebugInfo(const DebugInfoPtr &debug_info, const TraceInfoPtr &relation) {
37 auto trace_info = relation->clone();
38 trace_info->set_debug_info(debug_info);
39 return std::make_shared<NodeDebugInfo>(std::move(trace_info));
40 }
41
CloneNodeDebugInfo(const NodeDebugInfoPtr & debug_info)42 NodeDebugInfoPtr CloneNodeDebugInfo(const NodeDebugInfoPtr &debug_info) {
43 auto trace_info = std::make_shared<TraceCopy>(debug_info);
44 return std::make_shared<NodeDebugInfo>(std::move(trace_info));
45 }
46
CloneGraphDebugInfo(const GraphDebugInfoPtr & debug_info,const TraceInfoPtr & relation)47 GraphDebugInfoPtr CloneGraphDebugInfo(const GraphDebugInfoPtr &debug_info, const TraceInfoPtr &relation) {
48 auto trace_info = relation->clone();
49 trace_info->set_debug_info(debug_info);
50 return std::make_shared<GraphDebugInfo>(std::move(trace_info));
51 }
52 } // namespace
53
Cloner(const FuncGraphVector & func_graphs,bool clone_all_valuenodes,bool clone_all_child_graphs,bool clone_all_used_graphs,const TraceInfoPtr & relation,const TraceInfoPtr & target_relation)54 Cloner::Cloner(const FuncGraphVector &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
55 bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
56 : clone_all_valuenodes_(clone_all_valuenodes),
57 clone_all_child_graphs_(clone_all_child_graphs),
58 clone_all_used_graphs_(clone_all_used_graphs),
59 relation_(relation),
60 target_relation_(target_relation == nullptr ? relation : target_relation),
61 scope_(kDefaultScope),
62 type_(kBasic) {
63 for (auto &func_graph : func_graphs) {
64 AddClone(func_graph);
65 }
66 }
67
AddClone(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph,const AnfNodePtrList & params,CloneType type)68 void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
69 const AnfNodePtrList ¶ms, CloneType type) {
70 if (func_graph != nullptr) {
71 (void)todo_.emplace_back(CloneInfo{func_graph, target_func_graph, params});
72 type_ = type;
73 }
74 }
75
CloneNode(const AnfNodePtr & node,const FuncGraphPtr & target)76 void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
77 MS_EXCEPTION_IF_NULL(node);
78 if (replicated_node_.find(node) != replicated_node_.end()) {
79 return;
80 }
81 if (node->isa<CNode>()) {
82 CloneCNodeWithoutInputs(node, target);
83 } else if (node->isa<Parameter>()) {
84 CloneParameter(node, target, false);
85 }
86 }
87
CloneParameter(const AnfNodePtr & node,const FuncGraphPtr & target,bool is_add)88 void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
89 MS_EXCEPTION_IF_NULL(node);
90 MS_EXCEPTION_IF_NULL(target);
91 auto old_param = node->cast_ptr<Parameter>();
92 MS_EXCEPTION_IF_NULL(old_param);
93 auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
94 auto new_param = (is_add ? target->add_parameter(std::move(debug_info))
95 : std::make_shared<Parameter>(target, std::move(debug_info)));
96 if (preset_abstract()) {
97 new_param->set_abstract(old_param->abstract());
98 }
99 new_param->set_name(old_param->name());
100 if (old_param->has_default()) {
101 // Default parameter can be shared since it is readonly.
102 new_param->set_default_param(old_param->default_param());
103 }
104 new_param->set_is_top_graph_param(old_param->is_top_graph_param());
105 ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
106 new_param->set_scope(scope);
107 replicated_node_[node] = std::move(new_param);
108 }
109
110 // Create a new empty CNode for old one, and bind them.
111 // Also see LinkCNodeEdges().
CloneCNodeWithoutInputs(const AnfNodePtr & node,const FuncGraphPtr & target)112 void Cloner::CloneCNodeWithoutInputs(const AnfNodePtr &node, const FuncGraphPtr &target) {
113 MS_EXCEPTION_IF_NULL(node);
114 MS_EXCEPTION_IF_NULL(target);
115 auto old_node = node->cast<CNodePtr>();
116 AnfNodeWeakPtrList inputs;
117 inputs.reserve(old_node->size());
118 DebugInfoPtr debug_info;
119 if (this->update_info() != nullptr && this->update_info()->debug_info_ != nullptr) {
120 debug_info = this->update_info()->debug_info_;
121 } else {
122 debug_info = node->debug_info();
123 }
124
125 auto cloned_debug_info = CloneNodeDebugInfo(debug_info, relation_);
126 CNodePtr new_node = std::make_shared<CNode>(std::move(inputs), target, std::move(cloned_debug_info));
127 if (inline_call_node_ != nullptr) {
128 MS_LOG(DEBUG) << "inline_call_node_: " << inline_call_node_ << "/" << inline_call_node_->DebugString()
129 << ", new_node: " << new_node << "/" << new_node->DebugString();
130 UpdateInlineCNodeDebugInfo(inline_call_node_, new_node);
131 } else {
132 // Synchronize callers' shadow debug infos.
133 auto &new_shadow_debug_infos = new_node->debug_info()->shadow_debug_infos_map();
134 const auto &old_shadow_debug_infos = debug_info->shadow_debug_infos_map();
135 new_shadow_debug_infos.insert(old_shadow_debug_infos.cbegin(), old_shadow_debug_infos.cend());
136 }
137 new_node->CloneCNodeInfo(old_node);
138 // Copy to target graph
139 if (new_node->forward().first != nullptr) {
140 target->set_used_forward_nodes({new_node});
141 }
142 ScopePtr scope;
143 if (this->update_info() != nullptr && this->update_info()->scope_ != nullptr) {
144 scope = this->update_info()->scope_;
145 } else {
146 scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
147 }
148 new_node->set_scope(scope);
149 replicated_node_[node] = std::move(new_node);
150 }
151
CloneValueNode(const AnfNodePtr & node)152 void Cloner::CloneValueNode(const AnfNodePtr &node) {
153 MS_EXCEPTION_IF_NULL(node);
154 auto value_node = node->cast_ptr<ValueNode>();
155 MS_EXCEPTION_IF_NULL(value_node);
156 auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
157 ValueNodePtr new_const = NewValueNode(GetValueNode(node), std::move(debug_info));
158 ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
159 new_const->set_scope(scope);
160 if (preset_abstract()) {
161 new_const->set_abstract(node->abstract());
162 }
163 new_const->set_has_new_value(value_node->has_new_value());
164 replicated_node_[node] = std::move(new_const);
165 }
166
CloneFuncGraphValueNode(const AnfNodePtr & node,const FuncGraphPtr & target)167 void Cloner::CloneFuncGraphValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
168 MS_EXCEPTION_IF_NULL(node);
169 MS_EXCEPTION_IF_NULL(target);
170 auto value_node = node->cast_ptr<ValueNode>();
171 MS_EXCEPTION_IF_NULL(value_node);
172 auto debug_info = CloneNodeDebugInfo(node->debug_info(), relation_);
173 ValueNodePtr new_const = NewValueNode(target, std::move(debug_info));
174 ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
175 new_const->set_scope(scope);
176 if (preset_abstract()) {
177 new_const->set_abstract(node->abstract());
178 }
179 new_const->set_has_new_value(value_node->has_new_value());
180 replicated_node_[node] = std::move(new_const);
181 }
182
CloneValueNodes(const FuncGraphPtr & func_graph)183 void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
184 MS_EXCEPTION_IF_NULL(func_graph);
185 if (!clone_all_valuenodes_) {
186 return;
187 }
188 auto &value_nodes = func_graph->value_nodes();
189 for (auto &value_node : value_nodes) {
190 auto &old_node = value_node.first;
191 if (replicated_node_.find(old_node) == replicated_node_.end()) {
192 CloneValueNode(old_node);
193 }
194 }
195 }
196
AddChildGraphs(const FuncGraphPtr & func_graph)197 void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
198 MS_EXCEPTION_IF_NULL(func_graph);
199 MS_EXCEPTION_IF_NULL(manager_);
200 if (!clone_all_child_graphs_) {
201 return;
202 }
203 // The graph marked 'no_child_graph' has no child graph.
204 if (func_graph->has_flag(FUNC_GRAPH_FLAG_NO_CHILD_GRAPH)) {
205 return;
206 }
207 auto &scopes = manager_->scopes(func_graph);
208 std::set<const FuncGraph *> memo;
209 for (auto &graph : scopes) {
210 // Avoid to insert duplicate function.
211 if (graph == func_graph || !memo.emplace(graph.get()).second) {
212 continue;
213 }
214 (void)todo_.emplace_back(CloneInfo{graph, nullptr, {}});
215 }
216 }
217
AddTotalGraphs(const FuncGraphPtr & func_graph)218 void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
219 MS_EXCEPTION_IF_NULL(func_graph);
220 if (!clone_all_used_graphs_) {
221 return;
222 }
223 std::set<const FuncGraph *> memo;
224 auto &used = func_graph->func_graphs_used();
225 for (auto &fg : used) {
226 // Avoid to insert duplicate function.
227 if (!memo.emplace(fg.first.get()).second) {
228 continue;
229 }
230 (void)todo_.emplace_back(CloneInfo{fg.first, nullptr, {}});
231 }
232 }
233
CloneFuncGraphDefaultValues(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)234 void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
235 MS_EXCEPTION_IF_NULL(func_graph);
236 MS_EXCEPTION_IF_NULL(target_func_graph);
237 for (auto &item : func_graph->parameter_default_value()) {
238 auto nodes = TopoSort(item.second, SuccDeeperSimple);
239 for (auto &node : nodes) {
240 MS_EXCEPTION_IF_NULL(node);
241 if (node->isa<CNode>()) {
242 CloneNode(node, target_func_graph);
243 } else if (node->isa<ValueNode>()) {
244 CloneValueNode(node);
245 }
246 }
247 }
248 }
249
CloneFuncGraphValueNodes(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)250 void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
251 MS_EXCEPTION_IF_NULL(func_graph);
252 MS_EXCEPTION_IF_NULL(target_func_graph);
253
254 target_func_graph->set_stage(func_graph->stage());
255 target_func_graph->set_segment(func_graph->segment());
256 auto &old_return = func_graph->return_node();
257 if (old_return != nullptr) {
258 auto iter = replicated_node_.find(old_return);
259 if (iter == replicated_node_.end()) {
260 MS_LOG(INTERNAL_EXCEPTION) << "Can't find replicate node for return.";
261 }
262 MS_EXCEPTION_IF_NULL(iter->second);
263 auto return_node = iter->second->cast<CNodePtr>();
264 MS_EXCEPTION_IF_NULL(return_node);
265 target_func_graph->set_return(return_node);
266 } else {
267 MS_LOG(ERROR) << "Has no return node, func_graph: " << func_graph << "/" << func_graph->ToString();
268 }
269
270 auto &cnodes = func_graph->func_graph_cnodes_index();
271 for (auto &cnode : cnodes) {
272 MS_EXCEPTION_IF_NULL(cnode.first);
273 MS_EXCEPTION_IF_NULL(cnode.first->first);
274 auto user_cnode = cnode.first->first->cast_ptr<CNode>();
275 MS_EXCEPTION_IF_NULL(user_cnode);
276 const auto &valuenode = user_cnode->input(IntToSize(cnode.first->second));
277 if (valuenode == nullptr) {
278 continue;
279 }
280 CloneFuncGraphValueNode(valuenode, target_func_graph);
281 }
282 }
283
InlineCloneParameters(const FuncGraphPtr & func_graph,const AnfNodePtrList & params)284 void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) {
285 MS_EXCEPTION_IF_NULL(func_graph);
286 auto &old_params = func_graph->parameters();
287 if (old_params.size() != params.size()) {
288 MS_INTERNAL_EXCEPTION(TypeError) << "Origin params size[" << old_params.size() << "], inline params size["
289 << params.size() << "]";
290 }
291 for (size_t i = 0; i < old_params.size(); ++i) {
292 replicated_node_[old_params[i]] = params[i];
293 }
294 }
295
SetFuncGraphInfo(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph) const296 void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) const {
297 MS_EXCEPTION_IF_NULL(func_graph);
298 MS_EXCEPTION_IF_NULL(target_func_graph);
299 target_func_graph->set_attrs(func_graph->attrs());
300 target_func_graph->set_transforms(func_graph->transforms());
301 target_func_graph->set_has_vararg(func_graph->has_vararg());
302 target_func_graph->set_has_kwarg(func_graph->has_kwarg());
303 target_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
304 target_func_graph->set_fv_param_count(func_graph->fv_param_count());
305 target_func_graph->set_is_generate(func_graph->is_generated());
306 target_func_graph->set_stub(func_graph->stub());
307 target_func_graph->set_indirect(func_graph->indirect());
308 target_func_graph->set_python_obj(func_graph->python_obj());
309 target_func_graph->set_has_side_effect_node(func_graph->has_side_effect_node());
310 }
311
CloneParameters(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)312 void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
313 MS_EXCEPTION_IF_NULL(func_graph);
314 MS_EXCEPTION_IF_NULL(target_func_graph);
315 auto ¶ms = func_graph->parameters();
316 for (auto ¶m : params) {
317 CloneParameter(param, target_func_graph, true);
318 }
319 }
320
GenParameters(const FuncGraphPtr & func_graph)321 void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
322 MS_EXCEPTION_IF_NULL(func_graph);
323 auto &free_vars = manager_->free_variables_total();
324 auto iter = free_vars.find(func_graph);
325 if (iter == free_vars.end()) {
326 return;
327 }
328
329 CloneInfo item = todo_.back();
330 auto lift_top_func_graph = item.origin;
331 for (auto &fv_map : iter->second) {
332 auto &free_var = fv_map.first;
333 if (!utils::isa<AnfNodePtr>(free_var)) {
334 continue;
335 }
336 auto free_var_node = utils::cast<AnfNodePtr>(free_var);
337 // Don't lift weight parameter to top func_graph.
338 if (IsLiftTopFuncGraph(func_graph) && free_var_node->isa<Parameter>()) {
339 auto free_var_param = free_var_node->cast_ptr<Parameter>();
340 if (free_var_param->has_default()) {
341 MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->DebugString()
342 << " for top_func_graph: " << lift_top_func_graph->ToString();
343 continue;
344 }
345 }
346 auto &replicated_node = replicated_map_node_[func_graph];
347 if (replicated_node.find(free_var_node) != replicated_node.end()) {
348 MS_LOG(DEBUG) << "Param exists: " << free_var_node->DebugString()
349 << " for func_graph: " << func_graph->ToString();
350 continue;
351 }
352
353 MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
354 auto fv_parameter = AddParameter(func_graph, free_var_node);
355 fv_parameter->set_user_data<bool>("lifted_from_fv", std::make_shared<bool>(true));
356 auto &fg_params = replicated_func_graph_params_[func_graph];
357 (void)fg_params.emplace_back(fv_parameter);
358 }
359 }
360
CloneParameter(const ParameterPtr & param,const AnfNodePtr & node) const361 void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) const {
362 MS_EXCEPTION_IF_NULL(param);
363 MS_EXCEPTION_IF_NULL(node);
364 if (preset_abstract()) {
365 param->set_abstract(node->abstract());
366 }
367 if (node->isa<Parameter>()) {
368 auto old_param = node->cast_ptr<Parameter>();
369 if (old_param->has_default()) {
370 // Default parameter can be shared since it is readonly.
371 param->set_default_param(old_param->default_param());
372 }
373 param->set_name(old_param->name());
374 constexpr char lifted_user_data_key[] = "lifted_from_fv";
375 auto lifted = param->user_data<bool>(lifted_user_data_key);
376 if (lifted != nullptr && *lifted) {
377 param->set_user_data<bool>(lifted_user_data_key, std::make_shared<bool>(true));
378 }
379 }
380 }
381
AddParameter(const FuncGraphPtr & func_graph,const AnfNodePtr & node,bool is_add)382 ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
383 MS_EXCEPTION_IF_NULL(func_graph);
384 MS_EXCEPTION_IF_NULL(node);
385 auto debug_info = CloneNodeDebugInfo(node->debug_info());
386 ParameterPtr param = std::make_shared<Parameter>(func_graph, std::move(debug_info));
387 CloneParameter(param, node);
388 if (is_add) {
389 func_graph->add_parameter(param);
390 }
391 replicated_node_[param] = node;
392 replicated_map_node_[func_graph][node] = param;
393 return param;
394 }
395
396 namespace {
FilterMonadInput(const AnfNodeWeakPtrList & old_inputs,AnfNodeWeakPtrList * new_inputs,AnfNodePtr * possible_u_monad,AnfNodePtr * possible_io_monad)397 bool FilterMonadInput(const AnfNodeWeakPtrList &old_inputs, AnfNodeWeakPtrList *new_inputs,
398 AnfNodePtr *possible_u_monad, AnfNodePtr *possible_io_monad) {
399 AnfNodePtr local_u_monad = nullptr;
400 AnfNodePtr local_io_monad = nullptr;
401 for (const auto &weak_input : old_inputs) {
402 auto input = weak_input.lock();
403 MS_EXCEPTION_IF_NULL(input);
404 // Should be only one U Monad input.
405 if (HasAbstractUMonad(input)) {
406 if (local_u_monad != nullptr) {
407 MS_LOG(ERROR) << "Cannot have multiple U Monad in one call, first: " << local_u_monad->ToString()
408 << ", second: " << input->ToString();
409 return false;
410 }
411 local_u_monad = input;
412 continue;
413 }
414 // Should be only one IO Monad input.
415 if (HasAbstractIOMonad(input)) {
416 if (local_io_monad != nullptr) {
417 MS_LOG(ERROR) << "Cannot have multiple IO Monad in one call, first: " << local_io_monad->ToString()
418 << ", second: " << input->ToString();
419 return false;
420 }
421 local_io_monad = input;
422 continue;
423 }
424 // Collect all non-monad inputs.
425 (void)new_inputs->emplace_back(weak_input);
426 }
427 *possible_u_monad = local_u_monad;
428 *possible_io_monad = local_io_monad;
429 return true;
430 }
431
432 // After lift, func_graph will not refer any free variable, so DummyContext is proper.
BuildFuncGraphValueNode(const FuncGraphPtr & func_graph,bool preset_abstract)433 AnfNodePtr BuildFuncGraphValueNode(const FuncGraphPtr &func_graph, bool preset_abstract) {
434 auto new_node = NewValueNode(func_graph);
435 auto abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(
436 func_graph, abstract::AnalysisContext::DummyContext(), new_node, preset_abstract);
437 new_node->set_abstract(abstract);
438 return new_node;
439 }
440
BuildPrimitiveValueNode(const PrimitivePtr & primitive)441 AnfNodePtr BuildPrimitiveValueNode(const PrimitivePtr &primitive) {
442 auto new_node = NewValueNode(primitive);
443 auto abstract = std::make_shared<abstract::PrimitiveAbstractClosure>(primitive, new_node);
444 new_node->set_abstract(abstract);
445 return new_node;
446 }
447
PresetPartialAbstractClosure(const CNodePtr & cnode,const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & weak_inputs,bool preset_abstract)448 void PresetPartialAbstractClosure(const CNodePtr &cnode, const FuncGraphPtr &func_graph,
449 const AnfNodeWeakPtrList &weak_inputs, bool preset_abstract) {
450 if (!preset_abstract) {
451 return;
452 }
453 constexpr auto ignore_partial_fg_count = 2;
454 AbstractBasePtrList args_abs_list;
455 (void)std::for_each(weak_inputs.cbegin() + ignore_partial_fg_count, weak_inputs.cend(),
456 [&args_abs_list](const AnfNodeWeakPtr &weak_node) {
457 auto node = weak_node.lock();
458 MS_EXCEPTION_IF_NULL(node);
459 (void)args_abs_list.emplace_back(node->abstract());
460 });
461 MS_EXCEPTION_IF_NULL(func_graph->ToAbstract());
462 auto abs = std::make_shared<abstract::PartialAbstractClosure>(
463 func_graph->ToAbstract()->cast<abstract::AbstractFuncAtomPtr>(), args_abs_list, cnode);
464 cnode->set_abstract(abs);
465 }
466 } // namespace
467
IsLiftTopFuncGraph(const FuncGraphPtr & func_graph)468 bool Cloner::IsLiftTopFuncGraph(const FuncGraphPtr &func_graph) {
469 const auto &iter = std::find_if(todo_.begin(), todo_.end(),
470 [func_graph](const CloneInfo &item) -> bool { return item.origin == func_graph; });
471 if (iter == todo_.end()) {
472 return false;
473 }
474 return true;
475 }
476
OrderParameters(const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & inputs,size_t arg_start_index)477 void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList &inputs, size_t arg_start_index) {
478 MS_EXCEPTION_IF_NULL(func_graph);
479 mindspore::HashSet<AnfNodePtr> old_params;
480 for (auto ¶m : func_graph->parameters()) {
481 (void)old_params.insert(replicated_node_[param]);
482 }
483 mindspore::HashSet<AnfNodePtr> new_params;
484 AnfNodePtrList parameters;
485 // Ignore the 1st and 2nd param of inputs(such as. partial graph)
486 for (size_t i = arg_start_index; i < inputs.size(); ++i) {
487 const auto &input = inputs[i].lock();
488 MS_EXCEPTION_IF_NULL(input);
489 const auto ¶m = replicated_node_[input];
490 if (old_params.find(param) != old_params.end()) {
491 auto &new_param = replicated_map_node_[func_graph][param];
492 (void)parameters.emplace_back(new_param);
493 (void)new_params.insert(new_param);
494 }
495 }
496 for (auto ¶m : func_graph->parameters()) {
497 if (new_params.find(param) == new_params.end()) {
498 (void)parameters.emplace_back(param);
499 }
500 }
501 func_graph->set_parameters(std::move(parameters));
502 }
503
504 // Avoid to create nested partial CNode.
SetPartialEdges(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FuncGraphTransaction * tx)505 CNodePtr Cloner::SetPartialEdges(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FuncGraphTransaction *tx) {
506 if (!IsPrimitiveCNode(cnode, prim::kPrimPartial) || !IsValueNode<FuncGraph>(cnode->input(1))) {
507 return nullptr;
508 }
509 auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
510 MS_EXCEPTION_IF_NULL(graph);
511 auto &replicated_func_graph = replicated_map_func_graph_[func_graph];
512 if (replicated_func_graph.find(graph) == replicated_func_graph.end()) {
513 return nullptr;
514 }
515
516 auto partial_node = replicated_func_graph[graph];
517 if (!IsPrimitiveCNode(partial_node, prim::kPrimPartial)) {
518 return nullptr;
519 }
520 auto partial_cnode = dyn_cast<CNode>(partial_node);
521 MS_EXCEPTION_IF_NULL(partial_cnode);
522 auto value_node = BuildPrimitiveValueNode(prim::kPrimPartial);
523 MS_EXCEPTION_IF_NULL(value_node);
524 auto func_graph_node = BuildFuncGraphValueNode(graph, preset_abstract());
525 MS_EXCEPTION_IF_NULL(func_graph_node);
526 AnfNodeWeakPtrList new_inputs = {value_node, func_graph_node};
527 constexpr auto ignore_partial_fg_count = 2;
528 (void)std::copy(partial_cnode->weak_inputs().cbegin() + ignore_partial_fg_count, partial_cnode->weak_inputs().cend(),
529 std::back_inserter(new_inputs));
530 (void)std::copy(cnode->weak_inputs().cbegin() + ignore_partial_fg_count, cnode->weak_inputs().cend(),
531 std::back_inserter(new_inputs));
532 auto new_cnode = func_graph->NewCNodeWeak(std::move(new_inputs));
533 MS_EXCEPTION_IF_NULL(new_cnode);
534 PresetPartialAbstractClosure(new_cnode, graph, new_cnode->weak_inputs(), preset_abstract());
535
536 MS_LOG(DEBUG) << "Rebuild partial CNode, old_node: " << cnode->DebugString()
537 << ", partial_cnode: " << partial_cnode->DebugString() << ", new_node: " << new_cnode->DebugString()
538 << ", new_node abs: " << (new_cnode->abstract() != nullptr ? new_cnode->abstract()->ToString() : "null")
539 << ", partial " << graph->ToString() << " in " << func_graph->ToString();
540 (void)tx->Replace(cnode, new_cnode);
541 return new_cnode;
542 }
543
SetEdges(const FuncGraphPtr & func_graph,FuncGraphTransaction * tx)544 void Cloner::SetEdges(const FuncGraphPtr &func_graph, FuncGraphTransaction *tx) {
545 MS_EXCEPTION_IF_NULL(func_graph);
546 MS_EXCEPTION_IF_NULL(tx);
547 for (auto &node : func_graph->nodes()) {
548 auto cnode = dyn_cast<CNode>(node);
549 // Only cnode needed to be handled
550 if (cnode == nullptr) {
551 continue;
552 }
553
554 // Avoid to create nested partial CNode.
555 auto old_cnode = cnode;
556 auto new_cnode = SetPartialEdges(func_graph, cnode, tx);
557 if (new_cnode != nullptr) {
558 cnode = new_cnode;
559 }
560
561 const auto &inputs = cnode->inputs();
562 for (size_t i = 0; i < inputs.size(); ++i) {
563 auto &input = inputs[i];
564 if (IsValueNode<FuncGraph>(input)) {
565 if (i == 1 && new_cnode != nullptr) {
566 continue;
567 }
568 auto graph = GetValueNode<FuncGraphPtr>(input);
569 auto &replicated_func_graph = replicated_map_func_graph_[func_graph];
570 if (replicated_func_graph.find(graph) != replicated_func_graph.end()) {
571 auto partial_node = replicated_func_graph[graph];
572 tx->SetEdge(cnode, static_cast<int>(i), partial_node);
573 }
574 } else {
575 auto &replicated_node = replicated_map_node_[func_graph];
576 if (replicated_node.find(input) != replicated_node.end()) {
577 tx->SetEdge(cnode, static_cast<int>(i), replicated_node[input]);
578 }
579 }
580 }
581 }
582 }
583
AddParameters(const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & params,AnfNodeWeakPtrList * const lift_params,AnfNodeWeakPtrList * const input_params)584 void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodeWeakPtrList ¶ms,
585 AnfNodeWeakPtrList *const lift_params, AnfNodeWeakPtrList *const input_params) {
586 MS_EXCEPTION_IF_NULL(func_graph);
587 MS_EXCEPTION_IF_NULL(lift_params);
588 MS_EXCEPTION_IF_NULL(input_params);
589 AnfNodePtrList parameters;
590 mindspore::HashSet<AnfNodePtr> old_params;
591 for (auto ¶m : func_graph->parameters()) {
592 auto iter = replicated_node_.find(param);
593 if (iter != replicated_node_.end()) {
594 (void)old_params.insert(iter->second);
595 (void)parameters.emplace_back(param);
596 } else {
597 (void)parameters.emplace_back(AddParameter(func_graph, param, false));
598 (void)old_params.insert(param);
599 }
600 }
601 AnfNodePtr new_param = nullptr;
602 for (auto &weak_param : params) {
603 const auto ¶m = weak_param.lock();
604 auto old_param = replicated_node_[param];
605 MS_EXCEPTION_IF_NULL(old_param);
606 if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
607 replicated_node_[old_param] = old_param;
608 replicated_map_node_[func_graph][old_param] = old_param;
609 (void)input_params->emplace_back(old_param);
610 continue;
611 }
612 if (old_params.find(old_param) != old_params.end()) {
613 new_param = replicated_map_node_[func_graph][old_param];
614 if (new_param == nullptr) {
615 MS_LOG(INTERNAL_EXCEPTION) << "map_node, func_graph: " << func_graph->ToString()
616 << ", old_param: " << old_param->DebugString() << " cannot found";
617 }
618 (void)input_params->emplace_back(new_param);
619 continue;
620 }
621 if (IsLiftTopFuncGraph(func_graph)) {
622 // Don't lift parameter from used_graphs to my parameter if I am the top;
623 replicated_node_[old_param] = old_param;
624 replicated_map_node_[func_graph][old_param] = old_param;
625 MS_EXCEPTION_IF_NULL(old_param->func_graph());
626 replicated_map_node_[old_param->func_graph()][old_param] = old_param;
627 (void)input_params->emplace_back(old_param);
628 MS_LOG(DEBUG) << "Bypass " << old_param->DebugString() << " for top func_graph: " << func_graph->ToString();
629 continue;
630 }
631 new_param = AddParameter(func_graph, old_param, false);
632 (void)parameters.emplace_back(new_param);
633 (void)lift_params->emplace_back(new_param);
634 (void)input_params->emplace_back(new_param);
635 }
636 func_graph->set_parameters(std::move(parameters));
637 }
638
AddInputs(const FuncGraphPtr & func_graph_user,const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & params)639 void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
640 const AnfNodeWeakPtrList ¶ms) {
641 auto &replicated_func_graph = replicated_map_func_graph_[func_graph_user];
642 auto [iter, inserted] = replicated_func_graph.emplace(func_graph, nullptr);
643 if (inserted) {
644 const auto value_node = BuildPrimitiveValueNode(prim::kPrimPartial);
645 const auto fg_value = BuildFuncGraphValueNode(func_graph, preset_abstract());
646 AnfNodeWeakPtrList cnode_inputs{value_node, fg_value};
647 auto partial_node = func_graph_user->NewCNodeWeak(std::move(cnode_inputs));
648 iter->second = partial_node;
649 }
650 auto cnode = dyn_cast<CNode>(iter->second);
651 if (cnode == nullptr) {
652 return;
653 }
654 AnfNodePtr input_u_monad;
655 AnfNodePtr input_io_monad;
656 AnfNodePtr param_u_monad;
657 AnfNodePtr param_io_monad;
658 AnfNodeWeakPtrList inputs;
659 AnfNodeWeakPtrList add_params;
660 if (!FilterMonadInput(cnode->weak_inputs(), &inputs, &input_u_monad, &input_io_monad)) {
661 constexpr auto recursive_level = 2;
662 MS_LOG(INTERNAL_EXCEPTION) << "Cannot have multiple U Monad or multiple IO Monad in one CNode, cnode: "
663 << cnode->DebugString(recursive_level);
664 }
665 if (!FilterMonadInput(params, &add_params, ¶m_u_monad, ¶m_io_monad)) {
666 MS_LOG(INTERNAL_EXCEPTION) << "Cannot have multiple U Monad or multiple IO Monad in Parameters list, func_graph: "
667 << func_graph->ToString();
668 }
669
670 // Append new inputs from free variable.
671 constexpr auto caller_first_arg_index = 2;
672 for (size_t i = caller_first_arg_index; i < inputs.size(); i++) {
673 auto input = inputs[i].lock();
674 auto pos = std::find_if(add_params.cbegin(), add_params.cend(), [&input](const auto &weak_param) {
675 if (weak_param.lock() != nullptr && weak_param.lock() == input) {
676 return true;
677 }
678 return false;
679 });
680 if (pos != add_params.end()) {
681 (void)add_params.erase(pos);
682 }
683 }
684 (void)inputs.insert(inputs.end(), add_params.cbegin(), add_params.cend());
685
686 // Append monad inputs.
687 if (input_u_monad != nullptr && param_u_monad != nullptr && input_u_monad != param_u_monad) {
688 MS_LOG(INTERNAL_EXCEPTION) << "Cannot have multiple U Monad in one call, first: " << input_u_monad->ToString()
689 << ", second: " << param_u_monad->ToString();
690 }
691 if (input_io_monad != nullptr && param_io_monad != nullptr && input_io_monad != param_io_monad) {
692 MS_LOG(INTERNAL_EXCEPTION) << "Cannot have multiple IO Monad in one call, first: " << input_io_monad->ToString()
693 << ", second: " << param_io_monad->ToString();
694 }
695 auto &u_monad = (input_u_monad != nullptr ? input_u_monad : param_u_monad);
696 auto &io_monad = (input_io_monad != nullptr ? input_io_monad : param_io_monad);
697 if (u_monad != nullptr) {
698 (void)inputs.emplace_back(u_monad);
699 }
700 if (io_monad != nullptr) {
701 (void)inputs.emplace_back(io_monad);
702 }
703
704 cnode->set_weak_inputs(inputs);
705 OrderParameters(func_graph, inputs, caller_first_arg_index);
706 PresetPartialAbstractClosure(cnode, func_graph, inputs, preset_abstract());
707 MS_LOG(DEBUG) << "Create new partial CNode: " << cnode->DebugString();
708 }
709
LiftParameters(const FuncGraphPtr & func_graph_user,const FuncGraphPtr & func_graph,const AnfNodeWeakPtrList & params)710 void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
711 const AnfNodeWeakPtrList ¶ms) {
712 MS_EXCEPTION_IF_NULL(func_graph_user);
713 AnfNodeWeakPtrList lift_params;
714 AnfNodeWeakPtrList input_params;
715 AddParameters(func_graph_user, params, &lift_params, &input_params);
716 AddInputs(func_graph_user, func_graph, input_params);
717 if (lift_params.empty()) {
718 return;
719 }
720 for (auto &cnode_index : func_graph_user->func_graph_cnodes_index()) {
721 MS_EXCEPTION_IF_NULL(cnode_index.first);
722 const auto &user_node = cnode_index.first->first;
723 MS_EXCEPTION_IF_NULL(user_node);
724 LiftParameters(user_node->func_graph(), func_graph_user, lift_params);
725 }
726 }
727
Lift(const std::vector<FuncGraphPtr> & sorted)728 void Cloner::Lift(const std::vector<FuncGraphPtr> &sorted) {
729 // lift inner graph first
730 for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
731 auto func_graph = *r_iter;
732 auto iter = replicated_func_graph_params_.find(func_graph);
733 if (iter != replicated_func_graph_params_.end()) {
734 auto ¶ms = iter->second;
735 for (auto &cnode_index : func_graph->func_graph_cnodes_index()) {
736 MS_EXCEPTION_IF_NULL(cnode_index.first);
737 const auto &user_node = cnode_index.first->first;
738 MS_EXCEPTION_IF_NULL(user_node);
739 LiftParameters(user_node->func_graph(), func_graph, params);
740 }
741 }
742 }
743 }
744
SetEdgesBfs(const FuncGraphPtr & root_fg,FuncGraphTransaction * tx)745 void Cloner::SetEdgesBfs(const FuncGraphPtr &root_fg, FuncGraphTransaction *tx) {
746 MS_EXCEPTION_IF_NULL(root_fg);
747 const auto &func_graphs = BroadFirstSearchGraphUsed(root_fg, lifting_func_graph_filter());
748 for (auto &func_graph : func_graphs) {
749 SetEdges(func_graph, tx);
750 }
751 }
752
LiftParameters(const FuncGraphVector & todo_func_graphs)753 void Cloner::LiftParameters(const FuncGraphVector &todo_func_graphs) {
754 MS_EXCEPTION_IF_NULL(manager_);
755 auto tx = manager_->Transact();
756 for (const auto &todo_func_graph : todo_func_graphs) {
757 const auto &func_graphs = BroadFirstSearchGraphUsed(todo_func_graph, lifting_func_graph_filter());
758 for (auto &func_graph : func_graphs) {
759 GenParameters(func_graph);
760 }
761 Lift(func_graphs);
762 }
763 const auto &roots = manager_->roots();
764 // Roots in manager is not set in Pynative mode.
765 if (roots.empty()) {
766 for (const auto &todo_func_graph : todo_func_graphs) {
767 SetEdgesBfs(todo_func_graph, &tx);
768 }
769 } else {
770 for (const auto &root_func_graph : roots) {
771 SetEdgesBfs(root_func_graph, &tx);
772 }
773 }
774 tx.Commit();
775 }
776
CheckStatus(const FuncGraphPtr & func_graph,bool is_inline)777 bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
778 MS_EXCEPTION_IF_NULL(func_graph);
779 // Make sure only inline once
780 auto iter = status_.find(func_graph);
781 if (iter != status_.end()) {
782 if (is_inline == iter->second) {
783 return false;
784 }
785 if (clone_all_used_graphs_) {
786 MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False.";
787 return false;
788 }
789 }
790 return true;
791 }
792
CloneAllNodes(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)793 void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
794 MS_EXCEPTION_IF_NULL(func_graph);
795 MS_EXCEPTION_IF_NULL(target_func_graph);
796 const AnfNodeSet &nodes = func_graph->nodes();
797 replicated_node_.reserve(replicated_node_.size() + nodes.size());
798 for (auto &node : nodes) {
799 CloneNode(node, target_func_graph);
800 }
801 // Only func_graph is inlined, it cannot be found in repl;
802 if (replicated_func_graph_.find(func_graph) != replicated_func_graph_.end()) {
803 CloneOrderList(func_graph, target_func_graph);
804 }
805 }
806
CloneOrderList(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)807 void Cloner::CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
808 for (auto &weak_cnode : func_graph->order_list()) {
809 const auto &cnode = weak_cnode.lock();
810 if (cnode == nullptr) {
811 continue;
812 }
813 auto it = replicated_node_.find(cnode);
814 if (it == replicated_node_.end()) {
815 // For cnode which generated in Analyze phase, it cannot got from nodes API of func_graph,
816 // so it cannot be cloned in normal Clone API.
817 // If we ignore it, the order will be lost.
818 // Therefore we put this old node as placeholder to the order list of target func_graph to
819 // keep the order.
820 // It may be replaced in ProgramSpecialize.
821 // If this disconnected node is not used in target func_graph, it will be cleared after
822 // ProgramSpecialize;
823 target_func_graph->AppendOrderList(cnode);
824 continue;
825 }
826 auto replicated_cnode = dyn_cast<CNode>(it->second);
827 if (replicated_cnode != nullptr) {
828 target_func_graph->AppendOrderList(replicated_cnode);
829 }
830 }
831 }
832
Run()833 void Cloner::Run() {
834 if (todo_.empty()) {
835 return;
836 }
837
838 FuncGraphVector func_graphs;
839 (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
840 [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
841 if (type_ < kLifting) {
842 // Basic and Inline Clone
843 manager_ = Manage(func_graphs, false);
844 CloneNodes();
845 LinkCNodeEdges();
846 SetDefaults();
847 } else {
848 // Lifting Clone
849 manager_ = Manage(func_graphs);
850 LiftParameters(func_graphs);
851 }
852 }
853
CloneNodes()854 void Cloner::CloneNodes() {
855 while (!todo_.empty()) {
856 CloneInfo item = std::move(todo_.back());
857 todo_.pop_back();
858
859 const bool is_inline = (item.target != nullptr);
860 FuncGraphPtr &func_graph = item.origin;
861 (void)graph_set_.insert(func_graph);
862
863 if (!CheckStatus(func_graph, is_inline)) {
864 continue;
865 }
866
867 if (is_inline) {
868 InlineCloneParameters(func_graph, item.params);
869 CloneAllNodes(func_graph, item.target);
870 } else {
871 auto debug_info = CloneGraphDebugInfo(func_graph->debug_info(), target_relation_);
872 auto target_func_graph = std::make_shared<FuncGraph>(std::move(debug_info));
873 SetFuncGraphInfo(func_graph, target_func_graph);
874 CloneParameters(func_graph, target_func_graph);
875 replicated_func_graph_[func_graph] = target_func_graph;
876 CloneAllNodes(func_graph, target_func_graph);
877 CloneFuncGraphValueNodes(func_graph, target_func_graph);
878 CloneFuncGraphDefaultValues(func_graph, target_func_graph);
879 }
880
881 CloneValueNodes(func_graph);
882 AddChildGraphs(func_graph);
883 AddTotalGraphs(func_graph);
884 status_[func_graph] = is_inline;
885 }
886 }
887
888 // Link the CNode with its inputs.
889 // Also see CloneCNodeWithoutInputs()
LinkCNodeEdges()890 void Cloner::LinkCNodeEdges() {
891 for (auto &repl : replicated_node_) {
892 auto old_node = dyn_cast_ptr<CNode>(repl.first);
893 if (old_node == nullptr) {
894 continue;
895 }
896 MS_EXCEPTION_IF_NULL(repl.second);
897 auto new_node = repl.second->cast_ptr<CNode>();
898 MS_EXCEPTION_IF_NULL(new_node);
899 for (auto &weak_input : old_node->weak_inputs()) {
900 auto input = weak_input.lock();
901 MS_EXCEPTION_IF_NULL(input);
902 auto iter = replicated_node_.find(input);
903 auto &new_input = (iter == replicated_node_.end() ? input : iter->second);
904 new_node->add_input(new_input);
905 }
906 }
907 }
908
909 // For the graphs cloned, update its default value map to the cloned nodes.
SetDefaults()910 void Cloner::SetDefaults() {
911 for (auto &old_fg : graph_set_) {
912 MS_EXCEPTION_IF_NULL(old_fg);
913 auto iter = replicated_func_graph_.find(old_fg);
914 if (iter == replicated_func_graph_.end()) {
915 continue;
916 }
917 auto &new_fg = iter->second;
918 MS_EXCEPTION_IF_NULL(new_fg);
919 for (auto ¶m_def : old_fg->parameter_default_value()) {
920 auto replicated_iter = replicated_node_.find(param_def.second);
921 auto &value_node = (replicated_iter == replicated_node_.end() ? param_def.second : replicated_iter->second);
922 new_fg->set_param_default_value(param_def.first, value_node);
923 }
924 }
925 }
926
// Clones `root` into the replica of the graph that owns it and returns the clone.
// Raises an internal exception in two cases:
//   - the owning graph has no replica in this cloner;
//   - CloneNode() did not register a clone for `root`.
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto owner_iter = replicated_func_graph_.find(root->func_graph());
  if (owner_iter == replicated_func_graph_.end()) {
    MS_EXCEPTION_IF_NULL(root->func_graph());
    MS_LOG(INTERNAL_EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
  }
  CloneNode(root, owner_iter->second);

  auto clone_iter = replicated_node_.find(root);
  if (clone_iter == replicated_node_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
  }
  return clone_iter->second;
}
941
// Runs any pending clone work, then returns the clone of `node`, or `node` itself when it was not
// cloned.
AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
  {
    // Only the pending clone work is profiled, not the lookup below.
    MsProfileStatGuard profile_scope("func_graph_cloner_run.FuncGraphClonerNode");
    Run();
  }

  auto clone_iter = replicated_node_.find(node);
  if (clone_iter != replicated_node_.end()) {
    return clone_iter->second;
  }
  return node;
}
951
// Runs any pending clone work, then returns the replica of `func_graph`, or `func_graph` itself
// when it was not replicated. The returned graph takes over the python object of `func_graph`.
FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  {
    // Only the pending clone work is profiled, not the lookup below.
    MsProfileStatGuard profile_scope("func_graph_cloner_run.FuncGraphClonerGraph");
    Run();
  }

  FuncGraphPtr ret = func_graph;
  auto replica_iter = replicated_func_graph_.find(func_graph);
  if (replica_iter != replicated_func_graph_.end()) {
    ret = replica_iter->second;
  }
  ret->set_python_obj(func_graph->python_obj());
  return ret;
}
964
// Returns a basic clone of `func_graph`.
// - `clone_value_nodes` is passed to the Cloner constructor.
// - `update_info` is forwarded to the cloner when non-null.
// - GRAPH_FLAG_IS_WHILE_HEADER is propagated to the result when set on `func_graph`.
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes, const UpdateInfoPtr update_info) {
  MS_EXCEPTION_IF_NULL(func_graph);
  Cloner basic_cloner({func_graph}, clone_value_nodes, true, true);
  if (update_info != nullptr) {
    basic_cloner.set_update_info(update_info);
  }
  FuncGraphPtr target_func_graph = basic_cloner[func_graph];
  const bool is_while_header = func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER);
  if (!is_while_header) {
    return target_func_graph;
  }
  MS_EXCEPTION_IF_NULL(target_func_graph);
  target_func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  return target_func_graph;
}
978
// Inlines `func_graph` into `target_func_graph` and returns the node of `target_func_graph` that
// the cloner maps the output of `func_graph` to.
// - `func_graph_args` become the parameters of the inline clone request (AddClone with kInline).
//   Presumably they stand in for the parameters of `func_graph`; see InlineCloneParameters().
// - When `call_node` is non-null it must be a CNode with a non-null callee (input 0). If the
//   callee has a scope, that scope becomes the cloner's scope.
// - `call_node` (possibly null) is always handed to set_inline_call_node().
// - GRAPH_FLAG_IS_WHILE_HEADER and kTraining are copied to `target_func_graph` when set on
//   `func_graph`.
AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
                       const AnfNodePtrList &func_graph_args, const AnfNodePtr &call_node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(target_func_graph);
  Cloner cloner({}, false);
  if (call_node != nullptr) {
    auto call_cnode = dyn_cast<CNode>(call_node);
    MS_EXCEPTION_IF_NULL(call_cnode);
    // Validate the callee before dereferencing it: a null input 0 previously crashed here without
    // any diagnostic. Fetching it once also avoids a second input(0) lookup.
    auto callee = call_cnode->input(0);
    MS_EXCEPTION_IF_NULL(callee);
    auto callee_scope = callee->scope();
    if (callee_scope != nullptr) {
      cloner.set_scope(callee_scope);
    }
  }
  cloner.set_inline_call_node(call_node);
  cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
  if (func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
    target_func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  }
  if (func_graph->has_flag(kTraining)) {
    target_func_graph->set_flag(kTraining, true);
  }
  return cloner[func_graph->output()];
}
1001
// Lifting clone of `func_graph`.
// - The request is queued with AddClone(..., kLifting).
// - `preset_abstract` and `lifting_func_graph_filter` are forwarded to the cloner first.
// - Returns the graph the cloner maps `func_graph` to.
// - GRAPH_FLAG_IS_WHILE_HEADER is propagated to the result when set on `func_graph`.
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph, bool preset_abstract,
                          const GraphFilterFunc &lifting_func_graph_filter) {
  MS_EXCEPTION_IF_NULL(func_graph);
  Cloner cloner({}, false);
  cloner.set_preset_abstract(preset_abstract);
  cloner.set_lifting_func_graph_filter(lifting_func_graph_filter);
  cloner.AddClone(func_graph, nullptr, {}, kLifting);
  auto target_func_graph = cloner[func_graph];
  if (func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
    // Same guard as BasicClone(): check the clone result before dereferencing it.
    MS_EXCEPTION_IF_NULL(target_func_graph);
    target_func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  }
  return target_func_graph;
}
1015
// Lifting clone of several graphs with one shared cloner.
// Returns one graph per entry of `func_graphs`, in the same order: the cloned graph when one was
// produced, otherwise the input graph. Each returned graph takes over the python object of its
// input graph.
FuncGraphVector LiftingCloneMulti(const FuncGraphVector &func_graphs) {
  Cloner cloner({}, false);
  for (const auto &func_graph : func_graphs) {
    cloner.AddClone(func_graph, nullptr, {}, kLifting);
  }
  cloner.Run();

  FuncGraphVector lifted_func_graphs;
  lifted_func_graphs.reserve(func_graphs.size());
  const auto &replicated_func_graphs = cloner.cloned_func_graphs();
  for (const auto &func_graph : func_graphs) {
    FuncGraphPtr ret = func_graph;
    auto replica_iter = replicated_func_graphs.find(func_graph);
    if (replica_iter != replicated_func_graphs.end()) {
      ret = replica_iter->second;
    }
    MS_EXCEPTION_IF_NULL(ret);
    ret->set_python_obj(func_graph->python_obj());
    (void)lifted_func_graphs.emplace_back(std::move(ret));
  }
  return lifted_func_graphs;
}
1035
SpecializerClone(const FuncGraphPtr & func_graph,const TraceInfoPtr & relation)1036 ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
1037 MS_EXCEPTION_IF_NULL(func_graph);
1038 FuncGraphVector func_graphs = {func_graph};
1039 ClonerPtr cloner =
1040 std::make_shared<Cloner>(func_graphs, false, false, false, std::make_shared<TraceCopy>(), relation);
1041 {
1042 MsProfileStatGuard stat_guard("func_graph_cloner_run.FuncGraphSpecializer");
1043 cloner->Run();
1044 }
1045 return cloner;
1046 }
1047
// Builds a standalone copy of `func_graph` whose graph debug info is derived through `relation`.
// - The copy gets one new parameter per original parameter, with debug info cloned via
//   CloneNodeDebugInfo() and the same abstract as the original.
// - The body is cloned onto those new parameters and its output becomes the copy's output.
// - Graph-level properties are carried over: vararg/kwarg settings, kw-only and free-variable
//   parameter counts, generated/indirect/stub state, parameter default values (mapped through the
//   cloner), the FUNC_GRAPH_FLAG_IGNORE_VALUE and GRAPH_FLAG_IS_WHILE_HEADER flags, the graph
//   kernel attribute, stage and segment.
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto new_func_graph = std::make_shared<FuncGraph>(CloneGraphDebugInfo(func_graph->debug_info(), relation));

  // Mirror the parameter list first so the body can be cloned onto the new parameters.
  for (const auto &param : func_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(param);
    auto new_param = new_func_graph->add_parameter(CloneNodeDebugInfo(param->debug_info()));
    new_param->set_abstract(param->abstract());
  }

  // `true` is the Cloner argument that BasicClone() passes as `clone_value_nodes`.
  Cloner body_cloner({}, true);
  body_cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters());
  new_func_graph->set_output(body_cloner[func_graph->output()]);

  // Signature and bookkeeping properties.
  new_func_graph->set_has_vararg(func_graph->has_vararg());
  new_func_graph->set_has_kwarg(func_graph->has_kwarg());
  new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  new_func_graph->set_fv_param_count(func_graph->fv_param_count());
  new_func_graph->set_is_generate(func_graph->is_generated());
  new_func_graph->set_indirect(func_graph->indirect());
  new_func_graph->set_stub(func_graph->stub());

  // Default values must point at the cloned nodes, hence the lookup through the cloner.
  for (const auto &default_entry : func_graph->parameter_default_value()) {
    new_func_graph->set_param_default_value(default_entry.first, body_cloner[default_entry.second]);
  }

  // Flags and attributes.
  if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
    new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
  }
  if (func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
    new_func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  }
  if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
    new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  }

  // Pipeline placement.
  new_func_graph->set_stage(func_graph->stage());
  new_func_graph->set_segment(func_graph->segment());
  return new_func_graph;
}
1086 } // namespace mindspore
1087