/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"

#include <algorithm>
#include <map>
#include <set>
#include <tuple>
#include <unordered_set>
#include <utility>

#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/common/const_input_to_attr_registry.h"
#include "ir/func_graph_cloner.h"
#include "ir/func_graph.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/action.h"
#include "utils/context/graph_kernel_flags.h"
#include "vm/segment_runner.h"
#if ENABLE_D
#include "runtime/device/ascend/kernel_select_ascend.h"
#elif ENABLE_GPU
#include "runtime/device/gpu/kernel_info_setter.h"
#endif

namespace mindspore {
namespace opt {
namespace {
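// Check whether `out` (or the graph kernel it calls) outputs a MakeTuple; if so, collect the tuple's
// real outputs into real_outs.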
bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
  MS_EXCEPTION_IF_NULL(real_outs);
  if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
    auto &inputs = out->cast<CNodePtr>()->inputs();
    for (size_t i = 1; i < inputs.size(); ++i) {
      real_outs->push_back(inputs[i]);
    }
    return true;
  }

  if (auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); fg != nullptr) {
    auto fg_out = fg->output();
    if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) {
      auto inputs = fg_out->cast<CNodePtr>()->inputs();
      for (size_t i = 1; i < inputs.size(); ++i) {
        real_outs->push_back(inputs[i]);
      }
      return true;
    }
  }
  return false;
}

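// Flatten nested MakeTuple nodes in the graph output and return the list of real output nodes.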
AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
  AnfNodePtrList outs;
  auto out_node = fg->output();
  if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
    std::vector<AnfNodePtr> output_args;
    auto out_cnode = out_node->cast<CNodePtr>();
    for (auto out : out_cnode->inputs()) {
      if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
        auto inputs = out->cast<CNodePtr>()->inputs();
        for (size_t i = 1; i < inputs.size(); ++i) {
          output_args.push_back(inputs[i]);
        }
      } else {
        output_args.push_back(out);
      }
    }
    if (output_args.size() != out_cnode->inputs().size()) {
      auto new_out = fg->NewCNode(output_args);
      mng->Replace(out_node, new_out);
    }

    for (size_t i = 1; i < output_args.size(); ++i) {
      outs.push_back(output_args[i]);
    }
    return outs;
  }

  outs.push_back(out_node);
  return outs;
}

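// Generate the AKG fused-kernel json description for the given op nodes and their inputs/outputs;
// optionally return the address-to-node map of the generated json.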
bool GenJson(const AnfNodePtrList &op_nodes, const std::pair<AnfNodePtrList, AnfNodePtrList> &in_and_out,
             const DumpOption &dump_option, nlohmann::json *op_desc,
             std::map<std::string, AnfNodePtr> *address_node_map = nullptr) {
  kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option);
  if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, in_and_out.first, in_and_out.second)) {
    MS_LOG(ERROR) << "Collect json desc failed.";
    return false;
  }

  *op_desc = akg_kernel_json_generator.kernel_json();
  if (address_node_map != nullptr) {
    *address_node_map = akg_kernel_json_generator.address_node_map();
  }
  std::string fused_name;
  std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
    (void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
  });
  MS_LOG(DEBUG) << "Collect fusion json: " << fused_name;
  return true;
}

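// Create a value node that holds a tensor of the same rank but with every dimension set to 1,
// filled with the first element of the given tensor value node.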
AnfNodePtr ConvertToScalarTensor(const AnfNodePtr &value_node) {
  auto tensor = GetValueNode<tensor::TensorPtr>(value_node);
  MS_EXCEPTION_IF_NULL(tensor);
  auto type_id = tensor->data_type();
  ShapeVector new_shape;
  auto origin_ndim = IntToSize(tensor->DataDim());
  for (size_t i = 0; i < origin_ndim; ++i) {
    new_shape.push_back(1);
  }
  tensor::TensorPtr scalar_tensor = std::make_shared<tensor::Tensor>(type_id, new_shape);
  scalar_tensor->set_device_info(tensor->device_info());
  auto data_ptr = scalar_tensor->data_c();
  MS_EXCEPTION_IF_NULL(data_ptr);
  auto itemsize = static_cast<size_t>(tensor->data().itemsize());
  if (memcpy_s(data_ptr, static_cast<size_t>(itemsize), tensor->data_c(), itemsize) != 0) {
    MS_LOG(EXCEPTION) << "Failed to copy data from tensor into scalar.";
  }

  ValueNodePtr new_value_node = std::make_shared<ValueNode>(scalar_tensor);
  new_value_node->set_abstract(scalar_tensor->ToAbstract());
  new_value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{GetFormat(value_node)});
  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{type_id});
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());

  return new_value_node;
}

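// Replace each given tensor value node with a scalar-like tensor followed by a BroadcastTo node
// that restores the original shape.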
void ReplaceTensorWithScalar(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &scalar_tensors) {
  MS_EXCEPTION_IF_NULL(fg);
  if (scalar_tensors.empty()) {
    return;
  }

  auto sub_mng = fg->manager();
  if (sub_mng == nullptr) {
    sub_mng = Manage(fg, true);
    fg->set_manager(sub_mng);
  }

  std::map<AnfNodePtr, AnfNodePtr> to_be_replaced;
  for (auto scalar_tensor_node : scalar_tensors) {
    auto scalar = ConvertToScalarTensor(scalar_tensor_node);
    auto format = GetFormat(scalar_tensor_node);
    auto dst_shape_vec = GetShape(scalar_tensor_node);
    AnfNodePtrList new_broadcast_inputs = {NewValueNode(prim::kPrimBroadcastTo), scalar};
    auto broadcast_node = CreateCNode(new_broadcast_inputs, fg,
                                      {.format = format, .shape = dst_shape_vec, .type = GetType(scalar_tensor_node)});
    auto device_shape = GetDeviceShape(scalar_tensor_node);
    SetNodeAttrSafely("shape", MakeValue(device_shape), broadcast_node);
    to_be_replaced[scalar_tensor_node] = broadcast_node;
  }

  for (auto [old_value_node, new_node] : to_be_replaced) {
    sub_mng->Replace(old_value_node, new_node);
  }
}
}  // namespace

AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
  auto out_spec = node->abstract();
  if (out_spec->isa<abstract::AbstractTuple>()) {
    return out_spec->cast<abstract::AbstractTuplePtr>()->elements()[output_idx];
  }
  return out_spec;
}

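// Hoist non-scalar tensor value nodes out of the graph: value-equal tensors share one new parameter,
// and one representative value node per tensor is appended to inputs_ptr as the corresponding call argument.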
bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
  MS_EXCEPTION_IF_NULL(inputs_ptr);
  auto nodes = TopoSort(fg->get_return());

  std::vector<std::pair<tensor::TensorPtr, AnfNodePtrList>> v_replace;
  std::vector<AnfNodePtr> scalar_tensors;
  for (const auto &node : nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto &inputs = node->cast<CNodePtr>()->inputs();
    for (size_t i = 1; i < inputs.size(); ++i) {
      const auto &tnode = inputs[i];
      auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
      if (tensor == nullptr || tensor->DataSize() == 1) {
        continue;
      }
      auto tensor_iter = std::find_if(
        v_replace.begin(), v_replace.end(),
        [&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
      if (tensor_iter == v_replace.end()) {
        (void)v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
      } else {
        tensor_iter->second.push_back(tnode);
      }
    }
  }

  ReplaceTensorWithScalar(fg, scalar_tensors);

  if (v_replace.empty()) {
    return false;
  }

  auto mng = fg->manager();
  if (mng == nullptr) {
    mng = Manage(fg, false);
    fg->set_manager(mng);
  }

  auto &inputs = *inputs_ptr;
  for (auto iter : v_replace) {
    auto value_nodes = iter.second;
    if (value_nodes.empty()) {
      MS_LOG(EXCEPTION) << "Invalid value in map!";
    }

    auto vnode = value_nodes[0];
    auto parameter = fg->add_parameter();
    parameter->set_abstract(vnode->abstract());
    parameter->set_kernel_info(vnode->kernel_info_ptr());
    for (const auto &value_node : value_nodes) {
      mng->Replace(value_node, parameter);
    }

    inputs.push_back(vnode);
  }

  return true;
}

// Transform nodes (including basic and composite nodes) to a new graph, and collect their inputs and outputs.
std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes,
                                                                                AnfNodePtrList *src_outputs) {
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  AnfNodePtrList *soutputs = (src_outputs != nullptr) ? src_outputs : &outputs;
  std::tie(fg, inputs, *soutputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);

  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    mng = Manage(fg, false);
    fg->set_manager(mng);
  }

  // Inline the original graph kernels.
  auto cnodes = fg->GetOrderedCnodes();
  for (const auto &n : cnodes) {
    if (!AnfAlgo::IsGraphKernel(n)) {
      continue;
    }
    auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
    AnfNodePtrList ins;
    ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
    auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
    mng->Replace(n, out);
  }

  EliminateMakeTuple(fg, mng);
  ConvertNonscalarTensorToParameter(fg, &inputs);

  outputs.clear();
  kernel::GetFuncGraphOutputNodes(fg, &outputs);
  return std::make_tuple(fg, inputs, outputs);
}

// Rebuild the kernel build info when the node's inputs or outputs have changed; the processor comes from the node itself.
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
                                                      const std::vector<TypeId> &inputs_type,
                                                      const std::vector<std::string> &output_formats,
                                                      const std::vector<TypeId> &output_types, const AnfNodePtr &node) {
  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
  graph_info_builder.SetInputsFormat(inputs_format);
  graph_info_builder.SetInputsDeviceType(inputs_type);
  graph_info_builder.SetOutputsFormat(output_formats);
  graph_info_builder.SetOutputsDeviceType(output_types);
  graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(node));
  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  return graph_info_builder.Build();
}

// Build the kernel build info for a new node; the processor comes from the context.
kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
                                                      const std::vector<TypeId> &inputs_type,
                                                      const std::vector<std::string> &output_formats,
                                                      const std::vector<TypeId> &output_types) {
  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
  graph_info_builder.SetInputsFormat(inputs_format);
  graph_info_builder.SetInputsDeviceType(inputs_type);
  graph_info_builder.SetOutputsFormat(output_formats);
  graph_info_builder.SetOutputsDeviceType(output_types);
  graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  return graph_info_builder.Build();
}

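// Select the kernel build info (formats and device types) for the fused node and for the parameters of its
// sub graph, based on the real input and output kernels.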
void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                      const AnfNodePtrList &outputs) {
  std::vector<std::string> graph_input_format;
  std::vector<TypeId> graph_input_type;
  std::vector<std::string> graph_output_format;
  std::vector<TypeId> graph_output_type;
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
    if (kernel_with_index.first->isa<ValueNode>()) {
      auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
      MS_EXCEPTION_IF_NULL(tensor);
      (void)graph_input_format.emplace_back(kOpFormat_DEFAULT);
      (void)graph_input_type.emplace_back(tensor->data_type());
    } else {
      auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
      (void)graph_input_format.emplace_back(std::move(input_format));
      auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
      (void)graph_input_type.emplace_back(input_type);
    }
    auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
    fg->parameters()[i]->set_abstract(input_abs);
    fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
    kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
    para_info_builder.SetOutputsFormat({graph_input_format.back()});
    para_info_builder.SetOutputsDeviceType({graph_input_type.back()});
    para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
    para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
    AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i].get());
  }
  auto new_outputs = outputs;
  if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) {
    std::vector<AnfNodePtr> real_outs;
    if (IsMakeTupleOut(outputs[0], &real_outs)) {
      new_outputs = real_outs;
    }
  }
  for (size_t i = 0; i < new_outputs.size(); ++i) {
    auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0);
    auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
    graph_output_format.push_back(output_format);
    graph_output_type.push_back(output_type);
  }
  kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
  graph_info_builder.SetInputsFormat(graph_input_format);
  graph_info_builder.SetInputsDeviceType(graph_input_type);
  graph_info_builder.SetOutputsFormat(graph_output_format);
  graph_info_builder.SetOutputsDeviceType(graph_output_type);
  graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
  graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
  graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto graph_selected_info = graph_info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get());
}

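// Create the call node `fg(inputs...)` in func_graph and set its output abstract and the parameter abstracts of fg.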
AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
                              const AnfNodePtrList &outputs) {
  auto func_node = NewValueNode(fg);
  std::vector<AnfNodePtr> fn_inputs;
  fn_inputs.push_back(func_node);
  fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end());
  auto fuse_cnode = func_graph->NewCNode(fn_inputs);
  // Set output abstract
  if (outputs.size() > 1) {
    std::vector<AbstractBasePtr> out_specs;
    for (size_t i = 0; i < outputs.size(); ++i) {
      out_specs.push_back(outputs[i]->abstract());
    }
    auto out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
    fuse_cnode->set_abstract(out_spec);
  } else {
    fuse_cnode->set_abstract(outputs[0]->abstract());
  }
  // Set parameter abstract.
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
    auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
    fg->parameters()[i]->set_abstract(input_abs);
  }
  return fuse_cnode;
}

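// Replace the original output nodes with the new fused node, inserting or updating TupleGetItem nodes
// for the multi-output case.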
void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
                         const AnfNodePtrList &outputs) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  // Single output: replace it directly.
  if (outputs.size() == 1) {
    mng->Replace(outputs[0], new_fuse_cnode);
    return;
  }

  std::vector<AnfNodePtr> fn_inputs;
  size_t offset = 0;
  for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
    AnfNodePtrList real_outs;
    // Not a MakeTuple output: replace it with a TupleGetItem on the fused node.
    if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) {
      fn_inputs.clear();
      fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
      fn_inputs.push_back(new_fuse_cnode);
      fn_inputs.push_back(NewValueNode(MakeValue(SizeToLong(out_idx + offset))));
      auto new_out = func_graph->NewCNode(fn_inputs);
      new_out->set_abstract(outputs[out_idx]->abstract());
      mng->Replace(outputs[out_idx], new_out);
      continue;
    }

    // The output is a MakeTuple: update the index of its TupleGetItem users.
    auto users = mng->node_users()[outputs[out_idx]];
    for (auto &user : users) {
      auto use_node = user.first;
      if (!use_node->isa<CNode>() || !IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem)) {
        continue;
      }
      auto get_item_cnode = use_node->cast<CNodePtr>();
      auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
      MS_EXCEPTION_IF_NULL(value_input);
      auto value_node = value_input->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      auto item_idx = GetValue<int64_t>(value_node->value());
      int64_t new_item_idx = SizeToLong(out_idx + offset) + item_idx;
      fn_inputs.clear();
      fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
      fn_inputs.push_back(new_fuse_cnode);
      fn_inputs.push_back(NewValueNode(new_item_idx));
      auto new_out = func_graph->NewCNode(fn_inputs);
      new_out->set_abstract(get_item_cnode->abstract());
      mng->Replace(get_item_cnode, new_out);
    }

    offset += real_outs.size() - 1;
  }
}

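// Fuse the given nodes into a new graph kernel sub graph, replace them in kernel_graph with the fused call
// node, and name the sub graph after the fused ops plus the given postfix.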
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
                                                           const FuncGraphPtr &kernel_graph,
                                                           const std::string &postfix) {
  auto mng = kernel_graph->manager();
  if (mng == nullptr) {
    mng = Manage(kernel_graph, true);
    kernel_graph->set_manager(mng);
  }

  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList src_outputs;
  AnfNodePtrList outputs;

  std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs);
  auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
  SetNewKernelInfo(fuse_new_node, fg, inputs, outputs);
  // Handle the get-item problem.
  ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs);

  // Set the graph kernel attr.
  std::string fuse_op_name = "";
  for (auto &fuse_node : fuse_nodes) {
    if (IsPrimitiveCNode(fuse_node)) {
      fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
    } else if (AnfAlgo::IsGraphKernel(fuse_node)) {
      auto fuse_cnode = fuse_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(fuse_cnode);
      auto graph_kernel_fg = GetValueNode<FuncGraphPtr>(fuse_cnode->input(kAnfPrimitiveIndex));
      auto fg_flag_val = graph_kernel_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
      auto fuse_fg_name = GetValue<std::string>(fg_flag_val);
      fuse_op_name += fuse_fg_name + "_";
    }
  }
  fuse_op_name += postfix;
  fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));

  return std::make_tuple(fuse_new_node, src_outputs);
}

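// Dump a list of nodes (either a single composite node or pure basic nodes) to an AKG json description;
// optionally return the address-to-node map.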
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
                   std::map<std::string, AnfNodePtr> *address_node_map) {
  MS_EXCEPTION_IF_NULL(op_desc);
  if (nodes.empty()) {
    MS_LOG(ERROR) << "Input node list is empty.";
    return false;
  }
  bool has_graph_kernel = std::any_of(nodes.begin(), nodes.end(), AnfAlgo::IsGraphKernel);
  bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1;

  FuncGraphPtr fg;
  AnfNodePtrList op_nodes, inputs, outputs;
  if (is_single_graph_kernel) {
    fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
    kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
  } else if (!has_graph_kernel) {
    std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
    op_nodes = nodes;
  } else {
    // When there are both basic and composite ops, the composite ops should be inlined into the basic ones' graph,
    // so a new graph has to be generated (because they may be in the main graph!).
    // If address_node_map is wanted, the new nodes in the new graph should be mapped to the old nodes,
    // but that is not supported yet.
    MS_LOG(EXCEPTION) << "Mixing basic and composite ops is not supported yet!";
  }
  std::pair<AnfNodePtrList, AnfNodePtrList> in_and_out = std::make_pair(inputs, outputs);
  return GenJson(op_nodes, in_and_out, dump_option, op_desc, address_node_map);
}

bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc) {
  MS_EXCEPTION_IF_NULL(op_desc);
  if (nodes.empty()) {
    MS_LOG(ERROR) << "Input node list is empty.";
    return false;
  }

  FuncGraphPtr fg;
  AnfNodePtrList op_nodes, inputs, outputs;
  if (nodes.size() == 1 && AnfAlgo::IsGraphKernel(nodes[0])) {
    fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
  } else {
    std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(nodes);
    inputs.clear();
    outputs.clear();
  }

  kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);

  auto mng = fg->manager();
  if (mng == nullptr) {
    mng = Manage(fg, false);
    fg->set_manager(mng);
  }
  std::pair<AnfNodePtrList, AnfNodePtrList> in_and_out = std::make_pair(inputs, outputs);
  return GenJson(op_nodes, in_and_out, dump_option, op_desc);
}

bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc) {
  MS_EXCEPTION_IF_NULL(op_desc);
  std::vector<nlohmann::json> graphs_desc;
  for (auto const &graph_nodes : graphs) {
    nlohmann::json desc;
    if (!AnfToJsonDesc(graph_nodes, dump_option, &desc)) {
      MS_LOG(ERROR) << "Collect json desc failed.";
      return false;
    }
    graphs_desc.push_back(desc);
  }
  if (graphs_desc.empty()) {
    MS_LOG(ERROR) << "Collect zero json desc.";
    return false;
  }

  if (graphs_desc.size() > 1) {
    nlohmann::json op_json_desc;
    op_json_desc[kJsonKeyMultiGraph] = true;
    op_json_desc[kJsonKeyGraphDesc] = graphs_desc;
    *op_desc = op_json_desc;
    return true;
  }

  *op_desc = graphs_desc[0];
  return true;
}

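// Decode an AKG json description back into a FuncGraph.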
FuncGraphPtr JsonDescToAnf(const std::string &json_desc) {
  kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
  auto fg = akg_kernel_json_decoder.DecodeFusedNodes(json_desc);
  if (fg == nullptr) {
    MS_LOG(ERROR) << "Akg decode json to graph failed.";
    return nullptr;
  }
  return fg;
}

std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) {
  std::stringstream name;
  if (prefix != "") {
    name << prefix << "_";
  }
  for (const auto &node : cnodes) {
    if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
      name << AnfAlgo::GetCNodeName(node) << "_";
    }
  }
  if (postfix != "") {
    name << postfix;
  }
  return name.str();
}

void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
#if ENABLE_D
  device::ascend::SetKernelInfo(cnode, kernel_type);
#elif ENABLE_GPU
  cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
  device::gpu::SetKernelInfo(cnode, kernel_type);
#endif
}

std::string GetFormat(const AnfNodePtr &node) { return AnfAlgo::GetOutputFormat(node, 0); }

TypePtr GetType(const AnfNodePtr &node) {
  const auto &abstract = node->abstract();
  auto type = abstract->BuildType();
  MS_EXCEPTION_IF_NULL(type);
  return type;
}

ShapeVector GetShape(const AnfNodePtr &node) {
  auto abstract = node->abstract();
  MS_EXCEPTION_IF_NULL(abstract);
  auto shape = abstract->GetShapeTrack();
  if (shape == nullptr || !shape->isa<abstract::Shape>()) {
    MS_LOG(EXCEPTION) << "Cannot get shape from " << node->fullname_with_scope();
  }
  auto shape_vec = shape->cast<abstract::ShapePtr>()->shape();
  if (shape_vec.empty()) {
    shape_vec.push_back(1);
  }
  return shape_vec;
}

ShapeVector GetDeviceShape(const AnfNodePtr &node) {
  ShapeVector res_device_shape;
  auto device_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
  if (device_shape.empty()) {
    res_device_shape.push_back(1);
  } else {
    (void)std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(res_device_shape), SizeToLong);
  }
  return res_device_shape;
}

std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) {
  auto prim = GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(prim);
  const auto &attrs = prim->attrs();
  auto iter = attrs.find("axis");
  if (iter == attrs.end()) {
    MS_LOG(EXCEPTION) << "Origin node has no 'axis' attribute!";
  }

  std::vector<int64_t> axis;

  auto &v = iter->second;
  if (v->isa<ValueList>() || v->isa<ValueTuple>()) {
    auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value();
    for (auto value : vec) {
      if (value->isa<Int64Imm>()) {
        axis.push_back(GetValue<int64_t>(value));
      } else {
        MS_LOG(EXCEPTION) << "Reduce axis type should be int64!";
      }
    }
  } else if (v->isa<Int64Imm>()) {
    axis.push_back(GetValue<int64_t>(v));
  } else {
    MS_LOG(EXCEPTION) << "Reduce axis should be a list or tuple!";
  }

  return axis;
}

CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info,
                     bool use_fake_abstract) {
  // Limitation: 1. The node's attributes should be set outside of this function; 2. only one output.
  MS_EXCEPTION_IF_NULL(out_info.type);
  auto out_type = out_info.type;
  if (auto otype = out_info.type->cast<TensorTypePtr>(); otype != nullptr) {
    out_type = otype->element();
  }

  // Create CNode.
  auto cnode = func_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);

  // Setup abstract.
  if (use_fake_abstract) {
    auto abs_shape = kernel::GetFakeAbstractShape(out_info.shape, out_info.format);
    auto abs_tensor = std::make_shared<abstract::AbstractTensor>(out_type, abs_shape);
    cnode->set_abstract(abs_tensor);
  } else {
    auto abs_tensor = std::make_shared<abstract::AbstractTensor>(out_type, out_info.shape);
    cnode->set_abstract(abs_tensor);
  }

  // Setup kernel info.
  auto kernel_info = std::make_shared<device::KernelInfo>();
  cnode->set_kernel_info(kernel_info);
  std::vector<size_t> feature_map_input_indexs;
  kernel_info->set_feature_map_flag(false);
  for (size_t i = 1; i < inputs.size(); ++i) {
    if (AnfAlgo::IsFeatureMapOutput(inputs[i])) {
      kernel_info->set_feature_map_flag(true);
      feature_map_input_indexs.push_back(i);
    }
  }
  if (inputs.size() == 1) {
    kernel_info->set_feature_map_flag(true);
  }
  if (AnfAlgo::IsRealKernel(cnode)) {
    // If the node only has the primitive (such as GetNext), or one of the node's inputs is a feature map,
    // then the node's output is a feature map output.
    SetNodeAttrSafely(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
    SetNodeAttrSafely(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
  }

  // Setup kernel build info.
  std::vector<std::string> input_formats;
  std::vector<TypeId> input_types;
  for (size_t i = 1; i < inputs.size(); ++i) {
    auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
    auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
    input_formats.push_back(input_format);
    auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
    input_types.push_back(input_type);
  }

  std::vector<std::string> output_formats = {out_info.format};
  std::vector<TypeId> output_types = {out_type->type_id()};

  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetInputsFormat(input_formats);
  info_builder.SetInputsDeviceType(input_types);
  info_builder.SetOutputsFormat(output_formats);
  info_builder.SetOutputsDeviceType(output_types);
  info_builder.SetProcessor(kernel::GetProcessorFromContext());
  info_builder.SetKernelType(KernelType::AKG_KERNEL);
  info_builder.SetFusionType(kernel::FusionType::OPAQUE);
  auto selected_info = info_builder.Build();
  AnfAlgo::SetSelectKernelBuildInfo(selected_info, cnode.get());

  func_graph->AddNode(cnode);
  return cnode;
}

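// Set an attr on the node without affecting other nodes that share its primitive.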
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  // First make the CNode safe for setting an attr: give it its own cloned primitive.
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    return;
  }
  AnfNodePtrList new_inputs = {NewValueNode(AnfAlgo::GetCNodePrimitive(cnode)->Clone())};
  auto inputs = cnode->inputs();
  new_inputs.insert(new_inputs.end(), inputs.begin() + 1, inputs.end());
  cnode->set_inputs(new_inputs);

  // Then set the attr.
  AnfAlgo::SetNodeAttr(key, value, node);
}

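// Check whether a node must be kept as a basic op (dynamic shape, inplace/aggregate attrs, or the "skip"
// attr) and therefore must not be fused into a graph kernel.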
bool IsKeepBasicNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  // Dynamic shape is not supported yet.
  if (AnfAlgo::HasDynamicShapeFlag(AnfAlgo::GetCNodePrimitive(cnode))) {
    return true;
  }

  static const std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
                                                            "aggregate", "aggregate_input_indexx"};
  // If the node contains any attribute in contagious_attrs, it has to be kept basic no matter what the value is.
  if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
                  [&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); })) {
    return true;
  }
  if (AnfAlgo::GetBooleanAttr(cnode, "skip")) {
    return true;
  }
  return false;
}

void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
                  const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops) {
  auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
  if (!enable_ops_only.empty()) {
    ops->clear();
    (void)std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(*ops), new_prim);
  } else {
    if (!enable_ops.empty()) {
      (void)std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(*ops), new_prim);
    }
    if (!disable_ops.empty()) {
      auto iter = std::remove_if(ops->begin(), ops->end(), [&disable_ops](const PrimitivePtr &p) {
        return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
      });
      (void)ops->erase(iter, ops->end());
    }
  }
}

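// Convert a graph kernel FuncGraph into a graphkernel::LiteGraph, mapping its parameters, cnodes and outputs.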
graphkernel::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph) {
  graphkernel::LiteGraph::GraphBuilder gb(GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)));
  std::map<AnfNodePtr, graphkernel::NodePtr> node_map;
  auto todos = TopoSort(func_graph->output());
  const auto &params = func_graph->parameters();
  auto ExtractBuildInfo = [](const AnfNodePtr &node) {
    auto shape = GetDeviceShape(node);
    auto type = AnfAlgo::GetOutputDeviceDataType(node, 0);
    auto format = AnfAlgo::GetOutputFormat(node, 0);
    return graphkernel::NodeBase({shape, type, format});
  };
  // set inputs
  for (size_t i = 0; i < params.size(); i++) {
    node_map[params[i]] = gb.Parameter(ExtractBuildInfo(params[i]), std::string("input_") + std::to_string(i));
  }
  // set ops
  for (auto node : todos) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) continue;
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) break;
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    graphkernel::NodePtrList inputs;
    (void)std::transform(cnode->inputs().begin() + 1, cnode->inputs().end(), std::back_inserter(inputs),
                         [&node_map, &gb](const AnfNodePtr &no) {
                           auto iter = node_map.find(no);
                           if (iter != node_map.end()) {
                             return iter->second;
                           } else {
                             auto tensor = GetValueNode<tensor::TensorPtr>(no);
                             MS_EXCEPTION_IF_NULL(tensor);
                             return gb.Value(tensor);
                           }
                         });
    node_map[node] = gb.Op(AnfAlgo::GetCNodeName(node), ExtractBuildInfo(node), inputs, prim->attrs());
  }
  // set outputs
  auto output_node = func_graph->output();
  if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
    graphkernel::NodePtrList outputs;
    auto mt = output_node->cast<CNodePtr>();
    (void)std::transform(mt->inputs().begin() + 1, mt->inputs().end(), std::back_inserter(outputs),
                         [&node_map](const AnfNodePtr &no) { return node_map[no]; });
    gb.SetOutputs(std::move(outputs));
  } else {
    gb.SetOutputs({node_map[output_node]});
  }
  return gb.Get();
}

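// Convert a graphkernel::LiteGraph back into a FuncGraph, rebuilding the parameters, CNodes and the
// (possibly tuple) output; the real output nodes are returned through `outputs`.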
FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs) {
  auto func_graph = std::make_shared<FuncGraph>();
  std::map<graphkernel::NodePtr, AnfNodePtr> node_map;
  for (const auto &inp : lite_graph->inputs()) {
    auto param = func_graph->add_parameter();
    node_map[inp] = param;
    auto abs_shape = kernel::GetFakeAbstractShape(inp->shape, inp->format);
    param->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(inp->type), abs_shape));
    param->set_kernel_info(std::make_shared<device::KernelInfo>());
    auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
    AnfAlgo::SetSelectKernelBuildInfo(build_info, param.get());
  }
  // Create CNodes.
  for (const auto &op_node : lite_graph->GetOrderedNodes()) {
    if (op_node->NodeType() != graphkernel::NType::Primitive) {
      MS_LOG(EXCEPTION) << "Node " << op_node->name() << " should be a Primitive node";
    }
    auto op = std::static_pointer_cast<graphkernel::PrimOp>(op_node);
    AnfNodePtrList inputs = {NewValueNode(std::make_shared<Primitive>(op->op(), op->attrs()))};
    (void)std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(inputs),
                         [&node_map](const graphkernel::NodePtr &inp) -> AnfNodePtr {
                           auto iter = node_map.find(inp);
                           if (iter != node_map.end()) {
                             return iter->second;
                           } else {
                             if (inp->NodeType() != graphkernel::NType::Value) {
                               MS_LOG(EXCEPTION) << "Node " << inp->name() << " should be a Value node";
                             }
                             auto inp_value = inp->As<graphkernel::ConstTensorNode>()->data();
                             auto value_node = NewValueNode(inp_value);
                             value_node->set_abstract(inp_value->ToAbstract());
                             value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
                             auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
                             AnfAlgo::SetSelectKernelBuildInfo(build_info, value_node.get());
                             return value_node;
                           }
                         });
    auto cnode = CreateCNode(inputs, func_graph, {op->format, op->shape, TypeIdToType(op->type)}, true);
    MS_EXCEPTION_IF_NULL(cnode);
    node_map[op_node] = cnode;
  }
  if (lite_graph->GetOutputs().empty()) {
    MS_LOG(EXCEPTION) << "The output of LiteGraph " << lite_graph->name() << " is empty.";
  } else if (lite_graph->GetOutputs().size() == 1) {
    func_graph->set_output(node_map[lite_graph->GetOutputs()[0]]);
    if (outputs != nullptr) {
      (void)outputs->emplace_back(func_graph->output());
    }
  } else {
    AnfNodePtrList mt_inputs;
    AbstractBasePtrList out_abs_list;
    (void)std::transform(lite_graph->GetOutputs().begin(), lite_graph->GetOutputs().end(),
                         std::back_inserter(mt_inputs), [&node_map, &out_abs_list](const graphkernel::NodePtr &out) {
                           auto out_node = node_map[out];
                           MS_EXCEPTION_IF_NULL(out_node);
                           (void)out_abs_list.emplace_back(out_node->abstract());
                           return out_node;
                         });
    auto mt = func_graph->NewCNode(prim::kPrimMakeTuple, mt_inputs);
    mt->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
    mt->set_kernel_info(std::make_shared<device::KernelInfo>());
    func_graph->AddNode(mt);
    func_graph->set_output(mt);
    if (outputs != nullptr) {
      *outputs = std::move(mt_inputs);
    }
  }
  return func_graph;
}

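// Remove graph parameters that are no longer used, and drop the corresponding entries from `inputs`.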
void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
  const auto &ori_parameter = func_graph->parameters();
  auto todos = TopoSort(func_graph->get_return());
  std::set<AnfNodePtr> used_param;
  for (auto node : todos) {
    if (node->isa<Parameter>()) {
      (void)used_param.insert(node);
    }
  }
  if (used_param.size() == ori_parameter.size()) {
    return;
  }
  AnfNodePtrList new_parameter, new_inputs;
  for (size_t i = 0; i < ori_parameter.size(); ++i) {
    if (used_param.count(ori_parameter[i])) {
      new_parameter.push_back(ori_parameter[i]);
      new_inputs.push_back((*inputs)[i]);
    }
  }
  func_graph->set_parameters(new_parameter);
  *inputs = std::move(new_inputs);
}

std::vector<PrimitivePtr> GetValidOps(
  const std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> &ops_with_level, unsigned int level) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  std::vector<PrimitivePtr> valid_ops;
  for (const auto &[op_target, op_level, op] : ops_with_level) {
    if (op_target == kAllTarget || op_target == target) {
      if (level >= op_level) {
        (void)valid_ops.emplace_back(op);
      }
    }
  }
  return valid_ops;
}

FuncGraphManagerPtr GetFuncGraphManager(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  FuncGraphManagerPtr manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
    func_graph->set_manager(manager);
  }
  return manager;
}

void UpdateMng(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph) {
  mng->RemoveRoots();
  mng->KeepRoots({func_graph});
}
}  // namespace opt
}  // namespace mindspore