1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/func_graph_cloner.h"
18
19 #include <algorithm>
20
21 #include "ir/manager.h"
22 #include "ir/param_info.h"
23 #include "base/core_ops.h"
24 #include "utils/convert_utils_base.h"
25 #include "utils/log_adapter.h"
26 #include "utils/profile.h"
27 #include "utils/ms_context.h"
28 #include "ir/graph_utils.h"
29 #include "utils/parallel_node_check.h"
30
31 // namespace to support intermediate representation definition
32 namespace mindspore {
Cloner(const FuncGraphVector & func_graphs,bool clone_all_valuenodes,bool clone_all_child_graphs,bool clone_all_used_graphs,const TraceInfoPtr & relation,const TraceInfoPtr & target_relation)33 Cloner::Cloner(const FuncGraphVector &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
34 bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
35 : clone_all_valuenodes_(clone_all_valuenodes),
36 clone_all_child_graphs_(clone_all_child_graphs),
37 clone_all_used_graphs_(clone_all_used_graphs),
38 relation_(relation),
39 target_relation_(target_relation == nullptr ? relation : target_relation) {
40 for (auto &func_graph : func_graphs) {
41 AddClone(func_graph);
42 }
43 scope_ = kDefaultScope;
44 type_ = kBasic;
45 }
46
AddClone(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph,const AnfNodePtrList & params,CloneType type)47 void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
48 const AnfNodePtrList ¶ms, CloneType type) {
49 if (func_graph != nullptr) {
50 CloneInfo clone = {func_graph, target_func_graph, params};
51 todo_.push_back(clone);
52 type_ = type;
53 }
54 }
55
CloneNode(const AnfNodePtr & node,const FuncGraphPtr & target)56 void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
57 MS_EXCEPTION_IF_NULL(node);
58 if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
59 return;
60 }
61 if (node->isa<Parameter>()) {
62 CloneParameter(node, target);
63 } else if (node->isa<CNode>()) {
64 CloneCNode(node, target);
65 }
66 }
67
CloneParameter(const AnfNodePtr & node,const FuncGraphPtr & target,bool is_add)68 void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
69 MS_EXCEPTION_IF_NULL(node);
70 MS_EXCEPTION_IF_NULL(target);
71 TraceGuard trace_guard(node->debug_info(), relation_);
72 auto new_param = (is_add) ? target->add_parameter() : std::make_shared<Parameter>(target);
73 auto old_param = node->cast<ParameterPtr>();
74 MS_EXCEPTION_IF_NULL(old_param);
75 new_param->set_abstract(old_param->abstract());
76 new_param->set_name(old_param->name());
77 if (old_param->has_default()) {
78 // Default parameter can be shared since it is readonly.
79 new_param->set_default_param(old_param->default_param());
80 }
81 ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
82 new_param->set_scope(scope);
83 repl_node_[node] = new_param;
84 }
85
CloneCNode(const AnfNodePtr & node,const FuncGraphPtr & target)86 void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
87 MS_EXCEPTION_IF_NULL(node);
88 MS_EXCEPTION_IF_NULL(target);
89 TraceGuard trace_guard(node->debug_info(), relation_);
90 CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
91 auto old_node = node->cast<CNodePtr>();
92 new_node->CloneCNodeInfo(old_node);
93 ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
94 new_node->set_scope(scope);
95 repl_node_[old_node] = new_node;
96 nodes_.emplace_back(old_node, new_node);
97 }
98
CloneValueNode(const AnfNodePtr & node)99 void Cloner::CloneValueNode(const AnfNodePtr &node) {
100 MS_EXCEPTION_IF_NULL(node);
101 TraceGuard trace_guard(node->debug_info(), relation_);
102 ValueNodePtr new_const = NewValueNode(GetValueNode(node));
103 ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
104 new_const->set_scope(scope);
105 new_const->set_abstract(node->abstract());
106 new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
107 repl_node_[node] = new_const;
108 }
109
CloneValueNode(const AnfNodePtr & node,const FuncGraphPtr & target)110 void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
111 MS_EXCEPTION_IF_NULL(node);
112 MS_EXCEPTION_IF_NULL(target);
113 TraceGuard trace_guard(node->debug_info(), relation_);
114 ValueNodePtr new_const = NewValueNode(target);
115 ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
116 new_const->set_scope(scope);
117 new_const->set_abstract(node->abstract());
118 new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
119 repl_node_[node] = new_const;
120 }
121
CloneValueNodes(const FuncGraphPtr & func_graph)122 void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
123 MS_EXCEPTION_IF_NULL(func_graph);
124 MS_EXCEPTION_IF_NULL(manager_);
125 if (!clone_all_valuenodes_) {
126 return;
127 }
128 auto &value_nodes = func_graph->value_nodes();
129 for (auto &value_node : value_nodes) {
130 auto old_node = value_node.first;
131 MS_EXCEPTION_IF_NULL(old_node);
132 if (repl_node_.count(old_node) == 0) {
133 CloneValueNode(old_node);
134 }
135 }
136 }
137
AddChildGraphs(const FuncGraphPtr & func_graph)138 void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
139 MS_EXCEPTION_IF_NULL(func_graph);
140 MS_EXCEPTION_IF_NULL(manager_);
141 if (!clone_all_child_graphs_) {
142 return;
143 }
144 auto &scopes = manager_->scopes(func_graph);
145 for (auto &graph : scopes) {
146 if (graph != func_graph) {
147 todo_.push_back({graph, nullptr, {}});
148 }
149 }
150 }
151
AddTotalGraphs(const FuncGraphPtr & func_graph)152 void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
153 MS_EXCEPTION_IF_NULL(func_graph);
154 MS_EXCEPTION_IF_NULL(manager_);
155 if (!clone_all_used_graphs_) {
156 return;
157 }
158 auto &used = func_graph->func_graphs_used();
159 for (auto &fg : used) {
160 todo_.push_back({fg.first, nullptr, {}});
161 }
162 }
163
CloneFuncGraphDefaultValues(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)164 void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
165 MS_EXCEPTION_IF_NULL(func_graph);
166 MS_EXCEPTION_IF_NULL(target_func_graph);
167 for (auto &item : func_graph->parameter_default_value()) {
168 auto nodes = DeepLinkedGraphSearch(item.second);
169 for (auto &node : nodes) {
170 MS_EXCEPTION_IF_NULL(node);
171 if (node->isa<CNode>()) {
172 CloneNode(node, target_func_graph);
173 } else if (node->isa<ValueNode>()) {
174 CloneValueNode(node);
175 }
176 }
177 }
178 }
179
CloneFuncGraphValueNodes(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)180 void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
181 MS_EXCEPTION_IF_NULL(func_graph);
182 MS_EXCEPTION_IF_NULL(target_func_graph);
183 MS_EXCEPTION_IF_NULL(manager_);
184
185 target_func_graph->set_stage(func_graph->stage());
186 auto old_return = func_graph->get_return();
187 if (old_return != nullptr) {
188 auto iter = repl_node_.find(old_return);
189 if (iter == repl_node_.end()) {
190 MS_LOG(EXCEPTION) << "Can't find replicate node for return.";
191 }
192 MS_EXCEPTION_IF_NULL(iter->second);
193 auto return_node = iter->second->cast<CNodePtr>();
194 MS_EXCEPTION_IF_NULL(return_node);
195 target_func_graph->set_return(return_node);
196 }
197
198 auto &cnodes = func_graph->func_graph_cnodes_index();
199 for (auto &cnode : cnodes) {
200 auto parent = cnode.first->first->cast<CNodePtr>();
201 MS_EXCEPTION_IF_NULL(parent);
202 auto valuenode = parent->input(cnode.first->second);
203 CloneValueNode(valuenode, target_func_graph);
204 }
205 }
206
InlineCloneParameters(const FuncGraphPtr & func_graph,const AnfNodePtrList & params)207 void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) {
208 MS_EXCEPTION_IF_NULL(func_graph);
209 auto &old_params = func_graph->parameters();
210 if (old_params.size() != params.size()) {
211 MS_EXCEPTION(TypeError) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size()
212 << "]";
213 }
214 for (size_t i = 0; i < old_params.size(); ++i) {
215 repl_node_[old_params[i]] = params[i];
216 }
217 }
218
SetFuncGraphInfo(const FuncGraphPtr & func_graph,FuncGraphPtr * const target_func_graph)219 void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
220 MS_EXCEPTION_IF_NULL(func_graph);
221 MS_EXCEPTION_IF_NULL(target_func_graph);
222 TraceGuard trace_guard(func_graph->debug_info(), target_relation_);
223 *target_func_graph = std::make_shared<FuncGraph>();
224 (*target_func_graph)->set_attrs(func_graph->attrs());
225 (*target_func_graph)->set_transforms(func_graph->transforms());
226 (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
227 (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
228 (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
229 (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
230 (*target_func_graph)->set_is_generate(func_graph->is_generated());
231 (*target_func_graph)->set_stub(func_graph->stub());
232 (*target_func_graph)->set_switch_input(func_graph->switch_input());
233 (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
234 }
235
CloneParameters(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)236 void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
237 MS_EXCEPTION_IF_NULL(func_graph);
238 MS_EXCEPTION_IF_NULL(target_func_graph);
239 auto ¶ms = func_graph->parameters();
240 for (auto ¶m : params) {
241 CloneParameter(param, target_func_graph, true);
242 }
243 repl_func_graph_[func_graph] = target_func_graph;
244 }
245
GenParameters(const FuncGraphPtr & func_graph)246 void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
247 MS_EXCEPTION_IF_NULL(func_graph);
248 auto &free_vars = manager_->free_variables_total();
249 auto iter = free_vars.find(func_graph);
250 if (iter == free_vars.end()) {
251 return;
252 }
253
254 CloneInfo item = todo_.back();
255 auto lift_top_func_graph = item.origin;
256 for (auto &fv_map : iter->second) {
257 auto &free_var = fv_map.first;
258 if (utils::isa<AnfNodePtr>(free_var)) {
259 auto free_var_node = utils::cast<AnfNodePtr>(free_var);
260 // Don't lift weight parameter to top func_graph.
261 if (func_graph == lift_top_func_graph) {
262 if (free_var_node->isa<Parameter>()) {
263 auto free_var_param = free_var_node->cast<ParameterPtr>();
264 if (free_var_param->has_default()) {
265 MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->ToString()
266 << " for top_func_graph: " << lift_top_func_graph->ToString();
267 continue;
268 }
269 }
270 }
271 MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
272 repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
273 }
274 }
275 }
276
CloneParameter(const ParameterPtr & param,const AnfNodePtr & node)277 void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) {
278 param->set_abstract(node->abstract());
279 if (node->isa<Parameter>()) {
280 ParameterPtr old_param = dyn_cast<Parameter>(node);
281 if (old_param->has_default()) {
282 // Default parameter can be shared since it is readonly.
283 param->set_default_param(old_param->default_param());
284 }
285 param->set_name(old_param->name());
286 }
287 }
288
AddParameter(const FuncGraphPtr & func_graph,const AnfNodePtr & node,bool is_add)289 ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
290 TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
291 ParameterPtr param = std::make_shared<Parameter>(func_graph);
292 CloneParameter(param, node);
293 if (is_add) {
294 func_graph->add_parameter(param);
295 }
296 repl_node_[param] = node;
297 repl_map_node_[func_graph][node] = param;
298 return param;
299 }
300
AddParameters(const FuncGraphPtr & func_graph,const AnfNodePtrList & params,AnfNodePtrList * const lift_params,AnfNodePtrList * const input_params)301 void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms,
302 AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) {
303 AnfNodePtrList parameters;
304 std::unordered_set<AnfNodePtr> old_params;
305 for (auto ¶m : func_graph->parameters()) {
306 auto iter = repl_node_.find(param);
307 if (iter != repl_node_.end()) {
308 (void)old_params.insert(iter->second);
309 parameters.push_back(param);
310 } else {
311 parameters.push_back(AddParameter(func_graph, param, false));
312 (void)old_params.insert(param);
313 }
314 }
315 AnfNodePtr new_param = nullptr;
316 CloneInfo item = todo_.back();
317 auto lift_top_func_graph = item.origin;
318 for (auto ¶m : params) {
319 auto old_param = repl_node_[param];
320 if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
321 repl_node_[old_param] = old_param;
322 repl_map_node_[func_graph][old_param] = old_param;
323 input_params->push_back(old_param);
324 continue;
325 }
326 if (old_params.find(old_param) != old_params.end()) {
327 new_param = repl_map_node_[func_graph][old_param];
328 input_params->push_back(new_param);
329 continue;
330 }
331 if (lift_top_func_graph == func_graph) {
332 // Don't lift parameter from used_graphs to my parameter if I am the top;
333 repl_node_[old_param] = old_param;
334 input_params->push_back(old_param);
335 MS_LOG(DEBUG) << "Bypass param: " << old_param->ToString()
336 << " for top_func_graph: " << lift_top_func_graph->ToString();
337 continue;
338 }
339 new_param = AddParameter(func_graph, old_param, false);
340 parameters.push_back(new_param);
341 lift_params->push_back(new_param);
342 input_params->push_back(new_param);
343 }
344 func_graph->set_parameters(parameters);
345 }
346
347 namespace {
FilterMonadInput(const AnfNodePtrList & old_inputs,AnfNodePtrList * new_inputs,AnfNodePtr * possible_u_monad,AnfNodePtr * possible_io_monad)348 void FilterMonadInput(const AnfNodePtrList &old_inputs, AnfNodePtrList *new_inputs, AnfNodePtr *possible_u_monad,
349 AnfNodePtr *possible_io_monad) {
350 AnfNodePtr local_u_monad = nullptr, local_io_monad = nullptr;
351 (void)std::copy_if(old_inputs.cbegin(), old_inputs.cend(), std::back_inserter(*new_inputs),
352 [&local_u_monad, &local_io_monad](const auto &input) -> bool {
353 if (HasAbstractUMonad(input)) {
354 if (local_u_monad != nullptr) {
355 MS_LOG(EXCEPTION)
356 << "Cannot have multiple U Monad in one call, first: " << local_u_monad->ToString()
357 << ", second: " << input->ToString();
358 }
359 local_u_monad = input;
360 return false;
361 }
362 if (HasAbstractIOMonad(input)) {
363 if (local_io_monad != nullptr) {
364 MS_LOG(EXCEPTION)
365 << "Cannot have multiple IO Monad in one call, first: " << local_io_monad->ToString()
366 << ", second: " << input->ToString();
367 }
368 local_io_monad = input;
369 return false;
370 }
371 return true;
372 });
373 *possible_u_monad = local_u_monad;
374 *possible_io_monad = local_io_monad;
375 }
376 } // namespace
377
AddInputs(const FuncGraphPtr & func_graph_user,const FuncGraphPtr & func_graph,const AnfNodePtrList & params)378 void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
379 const AnfNodePtrList ¶ms) {
380 AnfNodePtr node = nullptr;
381 auto &repl_func_graph = repl_map_func_graph_[func_graph_user];
382 auto iter = repl_func_graph.find(func_graph);
383 if (iter == repl_func_graph.end()) {
384 node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
385 repl_func_graph[func_graph] = node;
386 } else {
387 node = iter->second;
388 }
389 if (node == nullptr || !node->isa<CNode>()) {
390 return;
391 }
392 auto cnode = node->cast<CNodePtr>();
393 AnfNodePtr input_u_monad = nullptr, input_io_monad = nullptr, param_u_monad = nullptr, param_io_monad = nullptr;
394 AnfNodePtrList inputs;
395 std::vector<AnfNodePtr> add_params;
396 FilterMonadInput(cnode->inputs(), &inputs, &input_u_monad, &input_io_monad);
397 FilterMonadInput(params, &add_params, ¶m_u_monad, ¶m_io_monad);
398
399 constexpr auto caller_first_arg_index = 2;
400 for (size_t i = caller_first_arg_index; i < inputs.size(); i++) {
401 auto ret = std::find(add_params.begin(), add_params.end(), inputs[i]);
402 if (ret != add_params.end()) {
403 add_params.erase(ret);
404 }
405 }
406 if (input_u_monad != nullptr && param_u_monad != nullptr && input_u_monad != param_u_monad) {
407 MS_LOG(EXCEPTION) << "Cannot have multiple U Monad in one call, first: " << input_u_monad->ToString()
408 << ", second: " << param_u_monad->ToString();
409 }
410 if (input_io_monad != nullptr && param_io_monad != nullptr && input_io_monad != param_io_monad) {
411 MS_LOG(EXCEPTION) << "Cannot have multiple IO Monad in one call, first: " << input_io_monad->ToString()
412 << ", second: " << param_io_monad->ToString();
413 }
414 (void)std::copy(add_params.begin(), add_params.end(), std::back_inserter(inputs));
415 auto &u_monad = input_u_monad != nullptr ? input_u_monad : param_u_monad;
416 auto &io_monad = input_io_monad != nullptr ? input_io_monad : param_io_monad;
417 if (u_monad != nullptr) {
418 inputs.push_back(u_monad);
419 }
420 if (io_monad != nullptr) {
421 inputs.push_back(io_monad);
422 }
423 cnode->set_inputs(inputs);
424 OrderParameters(func_graph, inputs, caller_first_arg_index);
425 }
426
OrderParameters(const FuncGraphPtr & func_graph,const AnfNodePtrList & inputs,size_t arg_start_index)427 void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs, size_t arg_start_index) {
428 std::unordered_set<AnfNodePtr> old_params;
429 for (auto ¶m : func_graph->parameters()) {
430 (void)old_params.insert(repl_node_[param]);
431 }
432 std::unordered_set<AnfNodePtr> new_params;
433 AnfNodePtrList parameters;
434 // Ignore the 1st and 2nd param of inputs(such as. partial graph)
435 for (size_t i = arg_start_index; i < inputs.size(); ++i) {
436 auto input = inputs[i];
437 auto param = repl_node_[input];
438 if (old_params.find(param) != old_params.end()) {
439 auto new_param = repl_map_node_[func_graph][param];
440 parameters.push_back(new_param);
441 (void)new_params.insert(new_param);
442 }
443 }
444 for (auto ¶m : func_graph->parameters()) {
445 if (new_params.find(param) == new_params.end()) {
446 parameters.push_back(param);
447 }
448 }
449 func_graph->set_parameters(parameters);
450 }
451
SetEdges(const FuncGraphPtr & func_graph,FuncGraphTransaction * tx)452 void Cloner::SetEdges(const FuncGraphPtr &func_graph, FuncGraphTransaction *tx) {
453 MS_EXCEPTION_IF_NULL(func_graph);
454 for (auto &node : func_graph->nodes()) {
455 if (node == nullptr) {
456 continue;
457 }
458 // Only cnode needed to be handled
459 if (!node->isa<CNode>()) {
460 continue;
461 }
462 auto cnode = node->cast<CNodePtr>();
463 auto &inputs = cnode->inputs();
464 for (size_t i = 0; i < inputs.size(); i++) {
465 auto &input = inputs[i];
466 if (IsValueNode<FuncGraph>(input)) {
467 auto graph = GetValueNode<FuncGraphPtr>(input);
468 auto &repl_func_graph = repl_map_func_graph_[func_graph];
469 if (repl_func_graph.find(graph) != repl_func_graph.end()) {
470 tx->SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
471 }
472 } else {
473 auto &repl_node = repl_map_node_[func_graph];
474 if (repl_node.find(input) != repl_node.end()) {
475 tx->SetEdge(cnode, SizeToInt(i), repl_node[input]);
476 }
477 }
478 }
479 }
480 }
481
LiftParameters(const FuncGraphPtr & func_graph_user,const FuncGraphPtr & func_graph,const AnfNodePtrList & params)482 void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
483 const AnfNodePtrList ¶ms) {
484 AnfNodePtrList lift_params;
485 AnfNodePtrList input_params;
486 AddParameters(func_graph_user, params, &lift_params, &input_params);
487 AddInputs(func_graph_user, func_graph, input_params);
488 if (lift_params.empty()) {
489 return;
490 }
491 for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
492 LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
493 }
494 }
495
Lift(const std::vector<FuncGraphPtr> & sorted)496 void Cloner::Lift(const std::vector<FuncGraphPtr> &sorted) {
497 // lift inner graph first
498 for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
499 auto func_graph = *r_iter;
500 auto iter = repl_func_graph_params_.find(func_graph);
501 if (iter != repl_func_graph_params_.end()) {
502 auto ¶ms = iter->second;
503 for (auto &cnode : func_graph->func_graph_cnodes_index()) {
504 LiftParameters(cnode.first->first->func_graph(), func_graph, params);
505 }
506 }
507 }
508 }
509
LiftParameters(const FuncGraphPtr & lift_top_func_graph)510 void Cloner::LiftParameters(const FuncGraphPtr &lift_top_func_graph) {
511 MS_EXCEPTION_IF_NULL(manager_);
512 auto tx = manager_->Transact();
513 const auto &func_graphs = BroadFirstSearchGraphUsed(lift_top_func_graph);
514 for (auto &func_graph : func_graphs) {
515 GenParameters(func_graph);
516 }
517 Lift(func_graphs);
518 for (auto &func_graph : func_graphs) {
519 SetEdges(func_graph, &tx);
520 }
521 tx.Commit();
522 }
523
CheckStatus(const FuncGraphPtr & func_graph,bool is_inline)524 bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
525 MS_EXCEPTION_IF_NULL(func_graph);
526 // Make sure only inline once
527 if (status_.count(func_graph) != 0) {
528 if (is_inline == status_[func_graph]) {
529 return false;
530 }
531 if (clone_all_used_graphs_) {
532 MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False.";
533 return false;
534 }
535 }
536 return true;
537 }
538
CloneAllNodes(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)539 void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
540 MS_EXCEPTION_IF_NULL(func_graph);
541 MS_EXCEPTION_IF_NULL(target_func_graph);
542 MS_EXCEPTION_IF_NULL(manager_);
543 const AnfNodeSet &nodes = func_graph->nodes();
544 for (auto &node : nodes) {
545 CloneNode(node, target_func_graph);
546 }
547 // Only func_graph is inlined, it cannot be found in repl;
548 if (repl_func_graph_.find(func_graph) != repl_func_graph_.end()) {
549 CloneOrderList(func_graph, target_func_graph);
550 }
551 }
552
CloneOrderList(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)553 void Cloner::CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
554 for (auto &cnode : func_graph->order_list()) {
555 auto it = repl_node_.find(cnode);
556 if (it == repl_node_.end()) {
557 // For cnode which generated in Analyze phase, it cannot got from nodes API of func_graph,
558 // so it cannot be cloned in normal Clone API.
559 // If we ignore it, the order will be lost.
560 // Therefore we put this old node as placeholder to the order list of target func_graph to
561 // keep the order.
562 // It may be replaced in ProgramSpecialize.
563 // If this disconnected node is not used in target func_graph, it will be cleared after
564 // ProgramSpecialize;
565 target_func_graph->AppendOrderList(cnode);
566 continue;
567 }
568 auto repl_cnode = dyn_cast<CNode>(it->second);
569 if (repl_cnode) {
570 target_func_graph->AppendOrderList(repl_cnode);
571 }
572 }
573 }
574
Run()575 void Cloner::Run() {
576 if (todo_.empty()) {
577 return;
578 }
579
580 if (type_ < kLifting) {
581 // Basic and Inline Clone
582 FuncGraphVector func_graphs;
583 (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
584 [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
585 manager_ = Manage(func_graphs, false);
586 CloneNodes();
587 LinkEdges();
588 SetDefaults();
589 } else {
590 // Lifting Clone
591 CloneInfo item = todo_.back();
592 manager_ = Manage(item.origin);
593 LiftParameters(item.origin);
594 }
595 }
596
CloneNodes()597 void Cloner::CloneNodes() {
598 while (!todo_.empty()) {
599 CloneInfo item = todo_.back();
600 todo_.pop_back();
601
602 bool is_inline = (item.target != nullptr);
603 FuncGraphPtr func_graph = item.origin;
604 FuncGraphPtr target_func_graph = item.target;
605 (void)graph_set_.insert(func_graph);
606
607 if (!CheckStatus(func_graph, is_inline)) {
608 continue;
609 }
610
611 if (is_inline) {
612 InlineCloneParameters(func_graph, item.params);
613 CloneAllNodes(func_graph, target_func_graph);
614 } else {
615 SetFuncGraphInfo(func_graph, &target_func_graph);
616 CloneParameters(func_graph, target_func_graph);
617 CloneAllNodes(func_graph, target_func_graph);
618 CloneFuncGraphValueNodes(func_graph, target_func_graph);
619 CloneFuncGraphDefaultValues(func_graph, target_func_graph);
620 }
621
622 CloneValueNodes(func_graph);
623 AddChildGraphs(func_graph);
624 AddTotalGraphs(func_graph);
625 status_[func_graph] = is_inline;
626 }
627 }
628
LinkEdges()629 void Cloner::LinkEdges() {
630 for (auto &node_pair : nodes_) {
631 CNodePtr old_node = node_pair.first;
632 CNodePtr new_node = node_pair.second;
633 MS_EXCEPTION_IF_NULL(old_node);
634 MS_EXCEPTION_IF_NULL(new_node);
635 for (auto &input : old_node->inputs()) {
636 auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
637 new_node->add_input(new_input);
638 }
639 }
640 }
641
642 // For the graphs cloned, update its default value map to the cloned nodes
SetDefaults()643 void Cloner::SetDefaults() {
644 for (auto &item : graph_set_) {
645 MS_EXCEPTION_IF_NULL(item);
646 if (repl_func_graph_.count(item) != 0) {
647 for (auto ¶m_def : item->parameter_default_value()) {
648 MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
649 if (repl_node_.count(param_def.second) != 0) {
650 repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
651 } else {
652 repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second);
653 }
654 }
655 }
656 }
657 }
658
/// Clones `root`, a node not reached by the normal graph walk, into the clone of its owning graph.
/// @return the clone of `root`.
/// Raises if root has no owning graph, if that graph was not cloned, or if cloning produced nothing.
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  const FuncGraphPtr owner = root->func_graph();
  // The old error path called owner->ToString() even when owner was null.
  if (owner == nullptr) {
    MS_LOG(EXCEPTION) << "Cannot clone disconnected node " << root->DebugString() << " without func graph.";
  }
  auto graph_iter = repl_func_graph_.find(owner);
  if (graph_iter == repl_func_graph_.end()) {
    MS_LOG(EXCEPTION) << "Cannot find func graph " << owner->ToString() << " in cloner.";
  }
  CloneNode(root, graph_iter->second);
  auto node_iter = repl_node_.find(root);
  if (node_iter == repl_node_.end()) {
    MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
  }
  return node_iter->second;
}
671
/// Runs any pending clone work, then returns the clone of `node`, or `node` itself when it has none.
AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE
  double start_time = GetTime();
#endif
  Run();
#ifdef ENABLE_PROFILE
  MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - start_time);
#endif
  auto found = repl_node_.find(node);
  if (found == repl_node_.end()) {
    return node;
  }
  return found->second;
}
682
/// Runs any pending clone work, then returns the clone of `func_graph`, or `func_graph` itself when it
/// has none.
FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
#ifdef ENABLE_PROFILE
  double start_time = GetTime();
#endif
  Run();
#ifdef ENABLE_PROFILE
  MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - start_time);
#endif
  auto found = repl_func_graph_.find(func_graph);
  if (found == repl_func_graph_.end()) {
    return func_graph;
  }
  return found->second;
}
693
/// Deep-copies `func_graph` together with its child and used graphs. Value nodes are copied only when
/// `clone_value_nodes` is set. Cloned nodes are traced with TraceCopy.
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto copy_relation = std::make_shared<TraceCopy>();
  Cloner graph_cloner({func_graph}, clone_value_nodes, true, true, copy_relation, nullptr);
  return graph_cloner[func_graph];
}
699
InlineClone(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph,const AnfNodePtrList & func_graph_args,const ScopePtr & scope)700 AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
701 const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
702 MS_EXCEPTION_IF_NULL(func_graph);
703 MS_EXCEPTION_IF_NULL(target_func_graph);
704 Cloner cloner({}, false);
705 if (scope != nullptr) {
706 cloner.set_scope(scope);
707 }
708 cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
709 return cloner[func_graph->output()];
710 }
711
/// Runs a lifting clone on `func_graph`, turning free variables of nested graphs into explicit parameters.
/// @return the graph that the cloner maps `func_graph` to (func_graph itself when it has no mapping).
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  Cloner lifter({}, false);
  lifter.AddClone(func_graph, nullptr, {}, kLifting);
  return lifter[func_graph];
}
718
SpecializerClone(const FuncGraphPtr & func_graph,const TraceInfoPtr & relation)719 ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
720 MS_EXCEPTION_IF_NULL(func_graph);
721 FuncGraphVector func_graphs = {func_graph};
722 ClonerPtr cloner =
723 std::make_shared<Cloner>(func_graphs, false, false, false, std::make_shared<TraceCopy>(), relation);
724 #ifdef ENABLE_PROFILE
725 double time = GetTime();
726 #endif
727 cloner->Run();
728 #ifdef ENABLE_PROFILE
729 MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time);
730 #endif
731 return cloner;
732 }
733
/// Builds a new graph traced to `func_graph` with `relation` by inlining func_graph into a fresh graph
/// with mirrored parameters.
/// It also copies:
///  - the vararg/kwarg/kwonly/hyper-param settings and the generated, stub and switch markers;
///  - parameter default values (remapped to the cloned nodes);
///  - the IGNORE_VALUES flag and the GRAPH_KERNEL attribute when present;
///  - the stage.
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  MS_EXCEPTION_IF_NULL(func_graph);
  TraceGuard guard(func_graph->debug_info(), relation);
  auto new_func_graph = std::make_shared<FuncGraph>();

  // One fresh parameter per original parameter; only the abstract is carried over here.
  for (const auto &source_param : func_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(source_param);
    TraceGuard trace_guard(std::make_shared<TraceCopy>(source_param->debug_info()));
    (void)new_func_graph->add_parameter()->set_abstract(source_param->abstract());
  }

  // Inline the body into the new graph, binding the original parameters to the fresh ones.
  Cloner cloner;
  cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters());
  new_func_graph->set_output(cloner[func_graph->output()]);

  new_func_graph->set_has_vararg(func_graph->has_vararg());
  new_func_graph->set_has_kwarg(func_graph->has_kwarg());
  new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
  new_func_graph->set_is_generate(func_graph->is_generated());
  new_func_graph->set_stub(func_graph->stub());
  new_func_graph->set_switch_input(func_graph->switch_input());
  new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
  for (const auto &default_entry : func_graph->parameter_default_value()) {
    new_func_graph->set_param_default_value(default_entry.first, cloner[default_entry.second]);
  }
  if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
    new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
  }
  if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
    new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  }
  new_func_graph->set_stage(func_graph->stage());

  return new_func_graph;
}
771 } // namespace mindspore
772