• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/func_graph_cloner.h"
18 
19 #include <algorithm>
20 
21 #include "ir/manager.h"
22 #include "ir/param_info.h"
23 #include "base/core_ops.h"
24 #include "utils/convert_utils_base.h"
25 #include "utils/log_adapter.h"
26 #include "utils/profile.h"
27 #include "utils/ms_context.h"
28 #include "ir/graph_utils.h"
29 #include "utils/parallel_node_check.h"
30 
31 // namespace to support intermediate representation definition
32 namespace mindspore {
Cloner(const FuncGraphVector & func_graphs,bool clone_all_valuenodes,bool clone_all_child_graphs,bool clone_all_used_graphs,const TraceInfoPtr & relation,const TraceInfoPtr & target_relation)33 Cloner::Cloner(const FuncGraphVector &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs,
34                bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation)
35     : clone_all_valuenodes_(clone_all_valuenodes),
36       clone_all_child_graphs_(clone_all_child_graphs),
37       clone_all_used_graphs_(clone_all_used_graphs),
38       relation_(relation),
39       target_relation_(target_relation == nullptr ? relation : target_relation) {
40   for (auto &func_graph : func_graphs) {
41     AddClone(func_graph);
42   }
43   scope_ = kDefaultScope;
44   type_ = kBasic;
45 }
46 
AddClone(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph,const AnfNodePtrList & params,CloneType type)47 void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
48                       const AnfNodePtrList &params, CloneType type) {
49   if (func_graph != nullptr) {
50     CloneInfo clone = {func_graph, target_func_graph, params};
51     todo_.push_back(clone);
52     type_ = type;
53   }
54 }
55 
CloneNode(const AnfNodePtr & node,const FuncGraphPtr & target)56 void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
57   MS_EXCEPTION_IF_NULL(node);
58   if (repl_node_.find(node) != repl_node_.end() || node->isa<ValueNode>()) {
59     return;
60   }
61   if (node->isa<Parameter>()) {
62     CloneParameter(node, target);
63   } else if (node->isa<CNode>()) {
64     CloneCNode(node, target);
65   }
66 }
67 
CloneParameter(const AnfNodePtr & node,const FuncGraphPtr & target,bool is_add)68 void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) {
69   MS_EXCEPTION_IF_NULL(node);
70   MS_EXCEPTION_IF_NULL(target);
71   TraceGuard trace_guard(node->debug_info(), relation_);
72   auto new_param = (is_add) ? target->add_parameter() : std::make_shared<Parameter>(target);
73   auto old_param = node->cast<ParameterPtr>();
74   MS_EXCEPTION_IF_NULL(old_param);
75   new_param->set_abstract(old_param->abstract());
76   new_param->set_name(old_param->name());
77   if (old_param->has_default()) {
78     // Default parameter can be shared since it is readonly.
79     new_param->set_default_param(old_param->default_param());
80   }
81   ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
82   new_param->set_scope(scope);
83   repl_node_[node] = new_param;
84 }
85 
CloneCNode(const AnfNodePtr & node,const FuncGraphPtr & target)86 void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
87   MS_EXCEPTION_IF_NULL(node);
88   MS_EXCEPTION_IF_NULL(target);
89   TraceGuard trace_guard(node->debug_info(), relation_);
90   CNodePtr new_node = std::make_shared<CNode>(AnfNodePtrList{}, target);
91   auto old_node = node->cast<CNodePtr>();
92   new_node->CloneCNodeInfo(old_node);
93   ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
94   new_node->set_scope(scope);
95   repl_node_[old_node] = new_node;
96   nodes_.emplace_back(old_node, new_node);
97 }
98 
CloneValueNode(const AnfNodePtr & node)99 void Cloner::CloneValueNode(const AnfNodePtr &node) {
100   MS_EXCEPTION_IF_NULL(node);
101   TraceGuard trace_guard(node->debug_info(), relation_);
102   ValueNodePtr new_const = NewValueNode(GetValueNode(node));
103   ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
104   new_const->set_scope(scope);
105   new_const->set_abstract(node->abstract());
106   new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
107   repl_node_[node] = new_const;
108 }
109 
CloneValueNode(const AnfNodePtr & node,const FuncGraphPtr & target)110 void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
111   MS_EXCEPTION_IF_NULL(node);
112   MS_EXCEPTION_IF_NULL(target);
113   TraceGuard trace_guard(node->debug_info(), relation_);
114   ValueNodePtr new_const = NewValueNode(target);
115   ScopePtr scope = ((node->scope() == kDefaultScope) && (this->scope() != nullptr)) ? this->scope() : node->scope();
116   new_const->set_scope(scope);
117   new_const->set_abstract(node->abstract());
118   new_const->set_has_new_value(node->cast<ValueNodePtr>()->has_new_value());
119   repl_node_[node] = new_const;
120 }
121 
CloneValueNodes(const FuncGraphPtr & func_graph)122 void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) {
123   MS_EXCEPTION_IF_NULL(func_graph);
124   MS_EXCEPTION_IF_NULL(manager_);
125   if (!clone_all_valuenodes_) {
126     return;
127   }
128   auto &value_nodes = func_graph->value_nodes();
129   for (auto &value_node : value_nodes) {
130     auto old_node = value_node.first;
131     MS_EXCEPTION_IF_NULL(old_node);
132     if (repl_node_.count(old_node) == 0) {
133       CloneValueNode(old_node);
134     }
135   }
136 }
137 
AddChildGraphs(const FuncGraphPtr & func_graph)138 void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) {
139   MS_EXCEPTION_IF_NULL(func_graph);
140   MS_EXCEPTION_IF_NULL(manager_);
141   if (!clone_all_child_graphs_) {
142     return;
143   }
144   auto &scopes = manager_->scopes(func_graph);
145   for (auto &graph : scopes) {
146     if (graph != func_graph) {
147       todo_.push_back({graph, nullptr, {}});
148     }
149   }
150 }
151 
AddTotalGraphs(const FuncGraphPtr & func_graph)152 void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
153   MS_EXCEPTION_IF_NULL(func_graph);
154   MS_EXCEPTION_IF_NULL(manager_);
155   if (!clone_all_used_graphs_) {
156     return;
157   }
158   auto &used = func_graph->func_graphs_used();
159   for (auto &fg : used) {
160     todo_.push_back({fg.first, nullptr, {}});
161   }
162 }
163 
CloneFuncGraphDefaultValues(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)164 void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
165   MS_EXCEPTION_IF_NULL(func_graph);
166   MS_EXCEPTION_IF_NULL(target_func_graph);
167   for (auto &item : func_graph->parameter_default_value()) {
168     auto nodes = DeepLinkedGraphSearch(item.second);
169     for (auto &node : nodes) {
170       MS_EXCEPTION_IF_NULL(node);
171       if (node->isa<CNode>()) {
172         CloneNode(node, target_func_graph);
173       } else if (node->isa<ValueNode>()) {
174         CloneValueNode(node);
175       }
176     }
177   }
178 }
179 
CloneFuncGraphValueNodes(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)180 void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
181   MS_EXCEPTION_IF_NULL(func_graph);
182   MS_EXCEPTION_IF_NULL(target_func_graph);
183   MS_EXCEPTION_IF_NULL(manager_);
184 
185   target_func_graph->set_stage(func_graph->stage());
186   auto old_return = func_graph->get_return();
187   if (old_return != nullptr) {
188     auto iter = repl_node_.find(old_return);
189     if (iter == repl_node_.end()) {
190       MS_LOG(EXCEPTION) << "Can't find replicate node for return.";
191     }
192     MS_EXCEPTION_IF_NULL(iter->second);
193     auto return_node = iter->second->cast<CNodePtr>();
194     MS_EXCEPTION_IF_NULL(return_node);
195     target_func_graph->set_return(return_node);
196   }
197 
198   auto &cnodes = func_graph->func_graph_cnodes_index();
199   for (auto &cnode : cnodes) {
200     auto parent = cnode.first->first->cast<CNodePtr>();
201     MS_EXCEPTION_IF_NULL(parent);
202     auto valuenode = parent->input(cnode.first->second);
203     CloneValueNode(valuenode, target_func_graph);
204   }
205 }
206 
InlineCloneParameters(const FuncGraphPtr & func_graph,const AnfNodePtrList & params)207 void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params) {
208   MS_EXCEPTION_IF_NULL(func_graph);
209   auto &old_params = func_graph->parameters();
210   if (old_params.size() != params.size()) {
211     MS_EXCEPTION(TypeError) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size()
212                             << "]";
213   }
214   for (size_t i = 0; i < old_params.size(); ++i) {
215     repl_node_[old_params[i]] = params[i];
216   }
217 }
218 
SetFuncGraphInfo(const FuncGraphPtr & func_graph,FuncGraphPtr * const target_func_graph)219 void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
220   MS_EXCEPTION_IF_NULL(func_graph);
221   MS_EXCEPTION_IF_NULL(target_func_graph);
222   TraceGuard trace_guard(func_graph->debug_info(), target_relation_);
223   *target_func_graph = std::make_shared<FuncGraph>();
224   (*target_func_graph)->set_attrs(func_graph->attrs());
225   (*target_func_graph)->set_transforms(func_graph->transforms());
226   (*target_func_graph)->set_has_vararg(func_graph->has_vararg());
227   (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg());
228   (*target_func_graph)->set_kwonlyargs_count(func_graph->kwonlyargs_count());
229   (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
230   (*target_func_graph)->set_is_generate(func_graph->is_generated());
231   (*target_func_graph)->set_stub(func_graph->stub());
232   (*target_func_graph)->set_switch_input(func_graph->switch_input());
233   (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
234 }
235 
CloneParameters(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)236 void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
237   MS_EXCEPTION_IF_NULL(func_graph);
238   MS_EXCEPTION_IF_NULL(target_func_graph);
239   auto &params = func_graph->parameters();
240   for (auto &param : params) {
241     CloneParameter(param, target_func_graph, true);
242   }
243   repl_func_graph_[func_graph] = target_func_graph;
244 }
245 
GenParameters(const FuncGraphPtr & func_graph)246 void Cloner::GenParameters(const FuncGraphPtr &func_graph) {
247   MS_EXCEPTION_IF_NULL(func_graph);
248   auto &free_vars = manager_->free_variables_total();
249   auto iter = free_vars.find(func_graph);
250   if (iter == free_vars.end()) {
251     return;
252   }
253 
254   CloneInfo item = todo_.back();
255   auto lift_top_func_graph = item.origin;
256   for (auto &fv_map : iter->second) {
257     auto &free_var = fv_map.first;
258     if (utils::isa<AnfNodePtr>(free_var)) {
259       auto free_var_node = utils::cast<AnfNodePtr>(free_var);
260       // Don't lift weight parameter to top func_graph.
261       if (func_graph == lift_top_func_graph) {
262         if (free_var_node->isa<Parameter>()) {
263           auto free_var_param = free_var_node->cast<ParameterPtr>();
264           if (free_var_param->has_default()) {
265             MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->ToString()
266                           << " for top_func_graph: " << lift_top_func_graph->ToString();
267             continue;
268           }
269         }
270       }
271       MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString();
272       repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var)));
273     }
274   }
275 }
276 
CloneParameter(const ParameterPtr & param,const AnfNodePtr & node)277 void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
278   param->set_abstract(node->abstract());
279   if (node->isa<Parameter>()) {
280     ParameterPtr old_param = dyn_cast<Parameter>(node);
281     if (old_param->has_default()) {
282       // Default parameter can be shared since it is readonly.
283       param->set_default_param(old_param->default_param());
284     }
285     param->set_name(old_param->name());
286   }
287 }
288 
AddParameter(const FuncGraphPtr & func_graph,const AnfNodePtr & node,bool is_add)289 ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) {
290   TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
291   ParameterPtr param = std::make_shared<Parameter>(func_graph);
292   CloneParameter(param, node);
293   if (is_add) {
294     func_graph->add_parameter(param);
295   }
296   repl_node_[param] = node;
297   repl_map_node_[func_graph][node] = param;
298   return param;
299 }
300 
AddParameters(const FuncGraphPtr & func_graph,const AnfNodePtrList & params,AnfNodePtrList * const lift_params,AnfNodePtrList * const input_params)301 void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &params,
302                            AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) {
303   AnfNodePtrList parameters;
304   std::unordered_set<AnfNodePtr> old_params;
305   for (auto &param : func_graph->parameters()) {
306     auto iter = repl_node_.find(param);
307     if (iter != repl_node_.end()) {
308       (void)old_params.insert(iter->second);
309       parameters.push_back(param);
310     } else {
311       parameters.push_back(AddParameter(func_graph, param, false));
312       (void)old_params.insert(param);
313     }
314   }
315   AnfNodePtr new_param = nullptr;
316   CloneInfo item = todo_.back();
317   auto lift_top_func_graph = item.origin;
318   for (auto &param : params) {
319     auto old_param = repl_node_[param];
320     if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) {
321       repl_node_[old_param] = old_param;
322       repl_map_node_[func_graph][old_param] = old_param;
323       input_params->push_back(old_param);
324       continue;
325     }
326     if (old_params.find(old_param) != old_params.end()) {
327       new_param = repl_map_node_[func_graph][old_param];
328       input_params->push_back(new_param);
329       continue;
330     }
331     if (lift_top_func_graph == func_graph) {
332       // Don't lift parameter from used_graphs to my parameter if I am the top;
333       repl_node_[old_param] = old_param;
334       input_params->push_back(old_param);
335       MS_LOG(DEBUG) << "Bypass param: " << old_param->ToString()
336                     << " for top_func_graph: " << lift_top_func_graph->ToString();
337       continue;
338     }
339     new_param = AddParameter(func_graph, old_param, false);
340     parameters.push_back(new_param);
341     lift_params->push_back(new_param);
342     input_params->push_back(new_param);
343   }
344   func_graph->set_parameters(parameters);
345 }
346 
347 namespace {
FilterMonadInput(const AnfNodePtrList & old_inputs,AnfNodePtrList * new_inputs,AnfNodePtr * possible_u_monad,AnfNodePtr * possible_io_monad)348 void FilterMonadInput(const AnfNodePtrList &old_inputs, AnfNodePtrList *new_inputs, AnfNodePtr *possible_u_monad,
349                       AnfNodePtr *possible_io_monad) {
350   AnfNodePtr local_u_monad = nullptr, local_io_monad = nullptr;
351   (void)std::copy_if(old_inputs.cbegin(), old_inputs.cend(), std::back_inserter(*new_inputs),
352                      [&local_u_monad, &local_io_monad](const auto &input) -> bool {
353                        if (HasAbstractUMonad(input)) {
354                          if (local_u_monad != nullptr) {
355                            MS_LOG(EXCEPTION)
356                              << "Cannot have multiple U Monad in one call, first: " << local_u_monad->ToString()
357                              << ", second: " << input->ToString();
358                          }
359                          local_u_monad = input;
360                          return false;
361                        }
362                        if (HasAbstractIOMonad(input)) {
363                          if (local_io_monad != nullptr) {
364                            MS_LOG(EXCEPTION)
365                              << "Cannot have multiple IO Monad in one call, first: " << local_io_monad->ToString()
366                              << ", second: " << input->ToString();
367                          }
368                          local_io_monad = input;
369                          return false;
370                        }
371                        return true;
372                      });
373   *possible_u_monad = local_u_monad;
374   *possible_io_monad = local_io_monad;
375 }
376 }  // namespace
377 
AddInputs(const FuncGraphPtr & func_graph_user,const FuncGraphPtr & func_graph,const AnfNodePtrList & params)378 void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
379                        const AnfNodePtrList &params) {
380   AnfNodePtr node = nullptr;
381   auto &repl_func_graph = repl_map_func_graph_[func_graph_user];
382   auto iter = repl_func_graph.find(func_graph);
383   if (iter == repl_func_graph.end()) {
384     node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)});
385     repl_func_graph[func_graph] = node;
386   } else {
387     node = iter->second;
388   }
389   if (node == nullptr || !node->isa<CNode>()) {
390     return;
391   }
392   auto cnode = node->cast<CNodePtr>();
393   AnfNodePtr input_u_monad = nullptr, input_io_monad = nullptr, param_u_monad = nullptr, param_io_monad = nullptr;
394   AnfNodePtrList inputs;
395   std::vector<AnfNodePtr> add_params;
396   FilterMonadInput(cnode->inputs(), &inputs, &input_u_monad, &input_io_monad);
397   FilterMonadInput(params, &add_params, &param_u_monad, &param_io_monad);
398 
399   constexpr auto caller_first_arg_index = 2;
400   for (size_t i = caller_first_arg_index; i < inputs.size(); i++) {
401     auto ret = std::find(add_params.begin(), add_params.end(), inputs[i]);
402     if (ret != add_params.end()) {
403       add_params.erase(ret);
404     }
405   }
406   if (input_u_monad != nullptr && param_u_monad != nullptr && input_u_monad != param_u_monad) {
407     MS_LOG(EXCEPTION) << "Cannot have multiple U Monad in one call, first: " << input_u_monad->ToString()
408                       << ", second: " << param_u_monad->ToString();
409   }
410   if (input_io_monad != nullptr && param_io_monad != nullptr && input_io_monad != param_io_monad) {
411     MS_LOG(EXCEPTION) << "Cannot have multiple IO Monad in one call, first: " << input_io_monad->ToString()
412                       << ", second: " << param_io_monad->ToString();
413   }
414   (void)std::copy(add_params.begin(), add_params.end(), std::back_inserter(inputs));
415   auto &u_monad = input_u_monad != nullptr ? input_u_monad : param_u_monad;
416   auto &io_monad = input_io_monad != nullptr ? input_io_monad : param_io_monad;
417   if (u_monad != nullptr) {
418     inputs.push_back(u_monad);
419   }
420   if (io_monad != nullptr) {
421     inputs.push_back(io_monad);
422   }
423   cnode->set_inputs(inputs);
424   OrderParameters(func_graph, inputs, caller_first_arg_index);
425 }
426 
OrderParameters(const FuncGraphPtr & func_graph,const AnfNodePtrList & inputs,size_t arg_start_index)427 void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs, size_t arg_start_index) {
428   std::unordered_set<AnfNodePtr> old_params;
429   for (auto &param : func_graph->parameters()) {
430     (void)old_params.insert(repl_node_[param]);
431   }
432   std::unordered_set<AnfNodePtr> new_params;
433   AnfNodePtrList parameters;
434   // Ignore the 1st and 2nd param of inputs(such as. partial graph)
435   for (size_t i = arg_start_index; i < inputs.size(); ++i) {
436     auto input = inputs[i];
437     auto param = repl_node_[input];
438     if (old_params.find(param) != old_params.end()) {
439       auto new_param = repl_map_node_[func_graph][param];
440       parameters.push_back(new_param);
441       (void)new_params.insert(new_param);
442     }
443   }
444   for (auto &param : func_graph->parameters()) {
445     if (new_params.find(param) == new_params.end()) {
446       parameters.push_back(param);
447     }
448   }
449   func_graph->set_parameters(parameters);
450 }
451 
SetEdges(const FuncGraphPtr & func_graph,FuncGraphTransaction * tx)452 void Cloner::SetEdges(const FuncGraphPtr &func_graph, FuncGraphTransaction *tx) {
453   MS_EXCEPTION_IF_NULL(func_graph);
454   for (auto &node : func_graph->nodes()) {
455     if (node == nullptr) {
456       continue;
457     }
458     // Only cnode needed to be handled
459     if (!node->isa<CNode>()) {
460       continue;
461     }
462     auto cnode = node->cast<CNodePtr>();
463     auto &inputs = cnode->inputs();
464     for (size_t i = 0; i < inputs.size(); i++) {
465       auto &input = inputs[i];
466       if (IsValueNode<FuncGraph>(input)) {
467         auto graph = GetValueNode<FuncGraphPtr>(input);
468         auto &repl_func_graph = repl_map_func_graph_[func_graph];
469         if (repl_func_graph.find(graph) != repl_func_graph.end()) {
470           tx->SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]);
471         }
472       } else {
473         auto &repl_node = repl_map_node_[func_graph];
474         if (repl_node.find(input) != repl_node.end()) {
475           tx->SetEdge(cnode, SizeToInt(i), repl_node[input]);
476         }
477       }
478     }
479   }
480 }
481 
LiftParameters(const FuncGraphPtr & func_graph_user,const FuncGraphPtr & func_graph,const AnfNodePtrList & params)482 void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph,
483                             const AnfNodePtrList &params) {
484   AnfNodePtrList lift_params;
485   AnfNodePtrList input_params;
486   AddParameters(func_graph_user, params, &lift_params, &input_params);
487   AddInputs(func_graph_user, func_graph, input_params);
488   if (lift_params.empty()) {
489     return;
490   }
491   for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
492     LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
493   }
494 }
495 
Lift(const std::vector<FuncGraphPtr> & sorted)496 void Cloner::Lift(const std::vector<FuncGraphPtr> &sorted) {
497   // lift inner graph first
498   for (auto r_iter = sorted.rbegin(); r_iter != sorted.rend(); ++r_iter) {
499     auto func_graph = *r_iter;
500     auto iter = repl_func_graph_params_.find(func_graph);
501     if (iter != repl_func_graph_params_.end()) {
502       auto &params = iter->second;
503       for (auto &cnode : func_graph->func_graph_cnodes_index()) {
504         LiftParameters(cnode.first->first->func_graph(), func_graph, params);
505       }
506     }
507   }
508 }
509 
LiftParameters(const FuncGraphPtr & lift_top_func_graph)510 void Cloner::LiftParameters(const FuncGraphPtr &lift_top_func_graph) {
511   MS_EXCEPTION_IF_NULL(manager_);
512   auto tx = manager_->Transact();
513   const auto &func_graphs = BroadFirstSearchGraphUsed(lift_top_func_graph);
514   for (auto &func_graph : func_graphs) {
515     GenParameters(func_graph);
516   }
517   Lift(func_graphs);
518   for (auto &func_graph : func_graphs) {
519     SetEdges(func_graph, &tx);
520   }
521   tx.Commit();
522 }
523 
CheckStatus(const FuncGraphPtr & func_graph,bool is_inline)524 bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) {
525   MS_EXCEPTION_IF_NULL(func_graph);
526   // Make sure only inline once
527   if (status_.count(func_graph) != 0) {
528     if (is_inline == status_[func_graph]) {
529       return false;
530     }
531     if (clone_all_used_graphs_) {
532       MS_LOG(ERROR) << "Try setting the `clone_all_used_graphs` option to False.";
533       return false;
534     }
535   }
536   return true;
537 }
538 
CloneAllNodes(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)539 void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
540   MS_EXCEPTION_IF_NULL(func_graph);
541   MS_EXCEPTION_IF_NULL(target_func_graph);
542   MS_EXCEPTION_IF_NULL(manager_);
543   const AnfNodeSet &nodes = func_graph->nodes();
544   for (auto &node : nodes) {
545     CloneNode(node, target_func_graph);
546   }
547   // Only func_graph is inlined, it cannot be found in repl;
548   if (repl_func_graph_.find(func_graph) != repl_func_graph_.end()) {
549     CloneOrderList(func_graph, target_func_graph);
550   }
551 }
552 
CloneOrderList(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph)553 void Cloner::CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) {
554   for (auto &cnode : func_graph->order_list()) {
555     auto it = repl_node_.find(cnode);
556     if (it == repl_node_.end()) {
557       // For cnode which generated in Analyze phase, it cannot got from nodes API of func_graph,
558       // so it cannot be cloned in normal Clone API.
559       // If we ignore it, the order will be lost.
560       // Therefore we put this old node as placeholder to the order list of target func_graph to
561       // keep the order.
562       // It may be replaced in ProgramSpecialize.
563       // If this disconnected node is not used in target func_graph, it will be cleared after
564       // ProgramSpecialize;
565       target_func_graph->AppendOrderList(cnode);
566       continue;
567     }
568     auto repl_cnode = dyn_cast<CNode>(it->second);
569     if (repl_cnode) {
570       target_func_graph->AppendOrderList(repl_cnode);
571     }
572   }
573 }
574 
Run()575 void Cloner::Run() {
576   if (todo_.empty()) {
577     return;
578   }
579 
580   if (type_ < kLifting) {
581     // Basic and Inline Clone
582     FuncGraphVector func_graphs;
583     (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs),
584                          [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; });
585     manager_ = Manage(func_graphs, false);
586     CloneNodes();
587     LinkEdges();
588     SetDefaults();
589   } else {
590     // Lifting Clone
591     CloneInfo item = todo_.back();
592     manager_ = Manage(item.origin);
593     LiftParameters(item.origin);
594   }
595 }
596 
CloneNodes()597 void Cloner::CloneNodes() {
598   while (!todo_.empty()) {
599     CloneInfo item = todo_.back();
600     todo_.pop_back();
601 
602     bool is_inline = (item.target != nullptr);
603     FuncGraphPtr func_graph = item.origin;
604     FuncGraphPtr target_func_graph = item.target;
605     (void)graph_set_.insert(func_graph);
606 
607     if (!CheckStatus(func_graph, is_inline)) {
608       continue;
609     }
610 
611     if (is_inline) {
612       InlineCloneParameters(func_graph, item.params);
613       CloneAllNodes(func_graph, target_func_graph);
614     } else {
615       SetFuncGraphInfo(func_graph, &target_func_graph);
616       CloneParameters(func_graph, target_func_graph);
617       CloneAllNodes(func_graph, target_func_graph);
618       CloneFuncGraphValueNodes(func_graph, target_func_graph);
619       CloneFuncGraphDefaultValues(func_graph, target_func_graph);
620     }
621 
622     CloneValueNodes(func_graph);
623     AddChildGraphs(func_graph);
624     AddTotalGraphs(func_graph);
625     status_[func_graph] = is_inline;
626   }
627 }
628 
LinkEdges()629 void Cloner::LinkEdges() {
630   for (auto &node_pair : nodes_) {
631     CNodePtr old_node = node_pair.first;
632     CNodePtr new_node = node_pair.second;
633     MS_EXCEPTION_IF_NULL(old_node);
634     MS_EXCEPTION_IF_NULL(new_node);
635     for (auto &input : old_node->inputs()) {
636       auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input];
637       new_node->add_input(new_input);
638     }
639   }
640 }
641 
642 // For the graphs cloned, update its default value map to the cloned nodes
SetDefaults()643 void Cloner::SetDefaults() {
644   for (auto &item : graph_set_) {
645     MS_EXCEPTION_IF_NULL(item);
646     if (repl_func_graph_.count(item) != 0) {
647       for (auto &param_def : item->parameter_default_value()) {
648         MS_EXCEPTION_IF_NULL(repl_func_graph_[item]);
649         if (repl_node_.count(param_def.second) != 0) {
650           repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]);
651         } else {
652           repl_func_graph_[item]->set_param_default_value(param_def.first, param_def.second);
653         }
654       }
655     }
656   }
657 }
658 
// Clone `root` into the replacement of its owning graph and return the registered
// replacement node. Raises when the owning graph was not cloned by this cloner, or when
// no replacement is registered for `root` after the attempt.
AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  auto graph_iter = repl_func_graph_.find(root->func_graph());
  if (graph_iter == repl_func_graph_.end()) {
    MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner.";
  }
  CloneNode(root, graph_iter->second);
  auto node_iter = repl_node_.find(root);
  if (node_iter == repl_node_.end()) {
    MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << ".";
  }
  return node_iter->second;
}
671 
// Run any pending clone requests, then return the replacement registered for `node`,
// or `node` itself when there is none.
AnfNodePtr Cloner::operator[](const AnfNodePtr &node) {
#ifdef ENABLE_PROFILE
  double time = GetTime();
#endif
  Run();
#ifdef ENABLE_PROFILE
  MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerNode", GetTime() - time);
#endif
  auto found = repl_node_.find(node);
  if (found == repl_node_.end()) {
    return node;
  }
  return found->second;
}
682 
// Run any pending clone requests, then return the replacement graph registered for
// `func_graph`, or `func_graph` itself when there is none.
FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) {
#ifdef ENABLE_PROFILE
  double time = GetTime();
#endif
  Run();
#ifdef ENABLE_PROFILE
  MsProfile::StatTime("func_graph_cloner_run.FuncGraphClonerGraph", GetTime() - time);
#endif
  auto found = repl_func_graph_.find(func_graph);
  if (found == repl_func_graph_.end()) {
    return func_graph;
  }
  return found->second;
}
693 
// Clone `func_graph` together with its child graphs and used graphs; value nodes are
// cloned as well when `clone_value_nodes` is true. Returns the cloned graph.
FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph, bool clone_value_nodes) {
  MS_EXCEPTION_IF_NULL(func_graph);
  const FuncGraphVector roots = {func_graph};
  Cloner cloner(roots, clone_value_nodes, true, true, std::make_shared<TraceCopy>(), nullptr);
  FuncGraphPtr cloned = cloner[func_graph];
  return cloned;
}
699 
InlineClone(const FuncGraphPtr & func_graph,const FuncGraphPtr & target_func_graph,const AnfNodePtrList & func_graph_args,const ScopePtr & scope)700 AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph,
701                        const AnfNodePtrList &func_graph_args, const ScopePtr &scope) {
702   MS_EXCEPTION_IF_NULL(func_graph);
703   MS_EXCEPTION_IF_NULL(target_func_graph);
704   Cloner cloner({}, false);
705   if (scope != nullptr) {
706     cloner.set_scope(scope);
707   }
708   cloner.AddClone(func_graph, target_func_graph, func_graph_args, kInline);
709   return cloner[func_graph->output()];
710 }
711 
// Run a lifting pass rooted at `func_graph` and return the graph the cloner reports
// for it afterwards.
FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  Cloner cloner({}, false);
  cloner.AddClone(func_graph, nullptr, {}, kLifting);
  FuncGraphPtr lifted = cloner[func_graph];
  return lifted;
}
718 
SpecializerClone(const FuncGraphPtr & func_graph,const TraceInfoPtr & relation)719 ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
720   MS_EXCEPTION_IF_NULL(func_graph);
721   FuncGraphVector func_graphs = {func_graph};
722   ClonerPtr cloner =
723     std::make_shared<Cloner>(func_graphs, false, false, false, std::make_shared<TraceCopy>(), relation);
724 #ifdef ENABLE_PROFILE
725   double time = GetTime();
726 #endif
727   cloner->Run();
728 #ifdef ENABLE_PROFILE
729   MsProfile::StatTime("func_graph_cloner_run.FuncGraphSpecializer", GetTime() - time);
730 #endif
731   return cloner;
732 }
733 
// Build a new graph with one fresh parameter per parameter of `func_graph` (abstracts
// copied), inline the body of `func_graph` into it, and copy the selected graph-level
// settings, default values, the IGNORE_VALUES flag, the GRAPH_KERNEL attribute and the stage.
FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
  MS_EXCEPTION_IF_NULL(func_graph);
  TraceGuard guard(func_graph->debug_info(), relation);
  auto new_func_graph = std::make_shared<FuncGraph>();

  for (const AnfNodePtr &origin_param : func_graph->parameters()) {
    MS_EXCEPTION_IF_NULL(origin_param);
    TraceGuard trace_guard(std::make_shared<TraceCopy>(origin_param->debug_info()));
    (void)new_func_graph->add_parameter()->set_abstract(origin_param->abstract());
  }

  Cloner cloner;
  cloner.AddClone(func_graph, new_func_graph, new_func_graph->parameters());
  AnfNodePtr cloned_output = cloner[func_graph->output()];
  new_func_graph->set_output(cloned_output);

  // Graph-level settings.
  new_func_graph->set_has_vararg(func_graph->has_vararg());
  new_func_graph->set_has_kwarg(func_graph->has_kwarg());
  new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count());
  new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
  new_func_graph->set_is_generate(func_graph->is_generated());
  new_func_graph->set_stub(func_graph->stub());
  new_func_graph->set_switch_input(func_graph->switch_input());
  new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());

  // Default values point at the cloned nodes where a clone exists.
  for (auto &default_entry : func_graph->parameter_default_value()) {
    new_func_graph->set_param_default_value(default_entry.first, cloner[default_entry.second]);
  }

  if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) {
    new_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true);
  }
  if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
    new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
  }
  new_func_graph->set_stage(func_graph->stage());

  return new_func_graph;
}
771 }  // namespace mindspore
772