1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
17
18 #include <algorithm>
19 #include <map>
20 #include <set>
21 #include <tuple>
22 #include <unordered_set>
23 #include <utility>
24
25 #include "backend/kernel_compiler/common_utils.h"
26 #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
27 #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
28 #include "backend/kernel_compiler/kernel.h"
29 #include "backend/session/anf_runtime_algorithm.h"
30 #include "backend/optimizer/common/const_input_to_attr_registry.h"
31 #include "ir/func_graph_cloner.h"
32 #include "ir/func_graph.h"
33 #include "pipeline/jit/parse/python_adapter.h"
34 #include "pipeline/jit/action.h"
35 #include "utils/context/graph_kernel_flags.h"
36 #include "vm/segment_runner.h"
37 #if ENABLE_D
38 #include "runtime/device/ascend/kernel_select_ascend.h"
39 #elif ENABLE_GPU
40 #include "runtime/device/gpu/kernel_info_setter.h"
41 #endif
42
43 namespace mindspore {
44 namespace opt {
45 namespace {
IsMakeTupleOut(const AnfNodePtr & out,AnfNodePtrList * real_outs)46 bool IsMakeTupleOut(const AnfNodePtr &out, AnfNodePtrList *real_outs) {
47 MS_EXCEPTION_IF_NULL(real_outs);
48 if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
49 auto &inputs = out->cast<CNodePtr>()->inputs();
50 for (size_t i = 1; i < inputs.size(); ++i) {
51 real_outs->push_back(inputs[i]);
52 }
53 return true;
54 }
55
56 if (auto fg = AnfAlgo::GetCNodeFuncGraphPtr(out); fg != nullptr) {
57 auto fg_out = fg->output();
58 if (IsPrimitiveCNode(fg_out, prim::kPrimMakeTuple)) {
59 auto inputs = fg_out->cast<CNodePtr>()->inputs();
60 for (size_t i = 1; i < inputs.size(); ++i) {
61 real_outs->push_back(inputs[i]);
62 }
63 return true;
64 }
65 }
66 return false;
67 }
68
EliminateMakeTuple(const FuncGraphPtr & fg,const FuncGraphManagerPtr & mng)69 AnfNodePtrList EliminateMakeTuple(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
70 AnfNodePtrList outs;
71 auto out_node = fg->output();
72 if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
73 std::vector<AnfNodePtr> output_args;
74 auto out_cnode = out_node->cast<CNodePtr>();
75 for (auto out : out_cnode->inputs()) {
76 if (IsPrimitiveCNode(out, prim::kPrimMakeTuple)) {
77 auto inputs = out->cast<CNodePtr>()->inputs();
78 for (size_t i = 1; i < inputs.size(); ++i) {
79 output_args.push_back(inputs[i]);
80 }
81 } else {
82 output_args.push_back(out);
83 }
84 }
85 if (output_args.size() != out_cnode->inputs().size()) {
86 auto new_out = fg->NewCNode(output_args);
87 mng->Replace(out_node, new_out);
88 }
89
90 for (size_t i = 1; i < output_args.size(); ++i) {
91 outs.push_back(output_args[i]);
92 }
93 return outs;
94 }
95
96 outs.push_back(out_node);
97 return outs;
98 }
99
GenJson(const AnfNodePtrList & op_nodes,const std::pair<AnfNodePtrList,AnfNodePtrList> & in_and_out,const DumpOption & dump_option,nlohmann::json * op_desc,std::map<std::string,AnfNodePtr> * address_node_map=nullptr)100 bool GenJson(const AnfNodePtrList &op_nodes, const std::pair<AnfNodePtrList, AnfNodePtrList> &in_and_out,
101 const DumpOption &dump_option, nlohmann::json *op_desc,
102 std::map<std::string, AnfNodePtr> *address_node_map = nullptr) {
103 kernel::AkgKernelJsonGenerator akg_kernel_json_generator(dump_option);
104 if (!akg_kernel_json_generator.CollectFusedJson(op_nodes, in_and_out.first, in_and_out.second)) {
105 MS_LOG(ERROR) << "Collect json desc failed.";
106 return false;
107 }
108
109 *op_desc = akg_kernel_json_generator.kernel_json();
110 if (address_node_map != nullptr) {
111 *address_node_map = akg_kernel_json_generator.address_node_map();
112 }
113 std::string fused_name;
114 std::for_each(op_nodes.begin(), op_nodes.end(), [&fused_name](const AnfNodePtr &node) {
115 (void)fused_name.append(AnfAlgo::GetCNodeName(node)).append("_");
116 });
117 MS_LOG(DEBUG) << "Collect fusion json: " << fused_name;
118 return true;
119 }
120
// Build a value node holding a "scalar" tensor — same rank and dtype as the
// tensor in `value_node` but with every dimension shrunk to 1 — whose single
// element is copied from the source tensor's first element.
AnfNodePtr ConvertToScalarTensor(const AnfNodePtr &value_node) {
  auto src_tensor = GetValueNode<tensor::TensorPtr>(value_node);
  MS_EXCEPTION_IF_NULL(src_tensor);
  auto type_id = src_tensor->data_type();
  // Keep the original rank but make every dimension 1.
  ShapeVector shape_of_ones(IntToSize(src_tensor->DataDim()), 1);
  auto scalar_tensor = std::make_shared<tensor::Tensor>(type_id, shape_of_ones);
  scalar_tensor->set_device_info(src_tensor->device_info());
  auto dst_ptr = scalar_tensor->data_c();
  MS_EXCEPTION_IF_NULL(dst_ptr);
  // Copy exactly one element from the source tensor.
  auto itemsize = static_cast<size_t>(src_tensor->data().itemsize());
  if (memcpy_s(dst_ptr, itemsize, src_tensor->data_c(), itemsize) != 0) {
    MS_LOG(EXCEPTION) << "Failed to copy data from tensor into scalar.";
  }

  ValueNodePtr new_value_node = std::make_shared<ValueNode>(scalar_tensor);
  new_value_node->set_abstract(scalar_tensor->ToAbstract());
  new_value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  auto info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  // Mirror the original value node's format; dtype comes from the tensor.
  info_builder->SetOutputsFormat(std::vector<std::string>{GetFormat(value_node)});
  info_builder->SetOutputsDeviceType(std::vector<TypeId>{type_id});
  AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), new_value_node.get());

  return new_value_node;
}
149
ReplaceTensorWithScalar(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & scalar_tensors)150 void ReplaceTensorWithScalar(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &scalar_tensors) {
151 MS_EXCEPTION_IF_NULL(fg);
152 if (scalar_tensors.empty()) {
153 return;
154 }
155
156 auto sub_mng = fg->manager();
157 if (sub_mng == nullptr) {
158 sub_mng = Manage(fg, true);
159 fg->set_manager(sub_mng);
160 }
161
162 std::map<AnfNodePtr, AnfNodePtr> to_be_replaced;
163 for (auto scalar_tensor_node : scalar_tensors) {
164 auto scalar = ConvertToScalarTensor(scalar_tensor_node);
165 auto format = GetFormat(scalar_tensor_node);
166 auto dst_shape_vec = GetShape(scalar_tensor_node);
167 AnfNodePtrList new_broadcast_inputs = {NewValueNode(prim::kPrimBroadcastTo), scalar};
168 auto broadcast_node = CreateCNode(new_broadcast_inputs, fg,
169 {.format = format, .shape = dst_shape_vec, .type = GetType(scalar_tensor_node)});
170 auto device_shape = GetDeviceShape(scalar_tensor_node);
171 SetNodeAttrSafely("shape", MakeValue(device_shape), broadcast_node);
172 to_be_replaced[scalar_tensor_node] = broadcast_node;
173 }
174
175 for (auto [old_value_node, new_node] : to_be_replaced) {
176 sub_mng->Replace(old_value_node, new_node);
177 }
178 }
179 } // namespace
180
GetOutputAbstract(const AnfNodePtr & node,size_t output_idx)181 AbstractBasePtr GetOutputAbstract(const AnfNodePtr &node, size_t output_idx) {
182 auto out_spec = node->abstract();
183 if (out_spec->isa<abstract::AbstractTuple>()) {
184 return out_spec->cast<abstract::AbstractTuplePtr>()->elements()[output_idx];
185 }
186 return out_spec;
187 }
188
ConvertNonscalarTensorToParameter(const FuncGraphPtr & fg,AnfNodePtrList * inputs_ptr)189 bool ConvertNonscalarTensorToParameter(const FuncGraphPtr &fg, AnfNodePtrList *inputs_ptr) {
190 MS_EXCEPTION_IF_NULL(inputs_ptr);
191 auto nodes = TopoSort(fg->get_return());
192
193 std::vector<std::pair<tensor::TensorPtr, AnfNodePtrList>> v_replace;
194 std::vector<AnfNodePtr> scalar_tensors;
195 for (const auto &node : nodes) {
196 if (!node->isa<CNode>()) {
197 continue;
198 }
199 auto &inputs = node->cast<CNodePtr>()->inputs();
200 for (size_t i = 1; i < inputs.size(); ++i) {
201 const auto &tnode = inputs[i];
202 auto tensor = GetValueNode<tensor::TensorPtr>(tnode);
203 if (tensor == nullptr || tensor->DataSize() == 1) {
204 continue;
205 }
206 auto tensor_iter = std::find_if(
207 v_replace.begin(), v_replace.end(),
208 [&tensor](const std::pair<tensor::TensorPtr, AnfNodePtrList> &vl) { return vl.first->ValueEqual(*tensor); });
209 if (tensor_iter == v_replace.end()) {
210 (void)v_replace.emplace_back(tensor, AnfNodePtrList{tnode});
211 } else {
212 tensor_iter->second.push_back(tnode);
213 }
214 }
215 }
216
217 ReplaceTensorWithScalar(fg, scalar_tensors);
218
219 if (v_replace.empty()) {
220 return false;
221 }
222
223 auto mng = fg->manager();
224 if (mng == nullptr) {
225 mng = Manage(fg, false);
226 fg->set_manager(mng);
227 }
228
229 auto &inputs = *inputs_ptr;
230 for (auto iter : v_replace) {
231 auto value_nodes = iter.second;
232 if (value_nodes.empty()) {
233 MS_LOG(EXCEPTION) << "Invalid value in map!";
234 }
235
236 auto vnode = value_nodes[0];
237 auto parameter = fg->add_parameter();
238 parameter->set_abstract(vnode->abstract());
239 parameter->set_kernel_info(vnode->kernel_info_ptr());
240 for (const auto &value_node : value_nodes) {
241 mng->Replace(value_node, parameter);
242 }
243
244 inputs.push_back(vnode);
245 }
246
247 return true;
248 }
249
250 // Transform nodes(including basic and composite node) to a new graph, and collect their inputs and outputs.
MixedNodesTransToGraph(const AnfNodePtrList & fuse_nodes,AnfNodePtrList * src_outputs)251 std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> MixedNodesTransToGraph(const AnfNodePtrList &fuse_nodes,
252 AnfNodePtrList *src_outputs) {
253 FuncGraphPtr fg;
254 AnfNodePtrList inputs;
255 AnfNodePtrList outputs;
256 AnfNodePtrList *soutputs = (src_outputs != nullptr) ? src_outputs : &outputs;
257 std::tie(fg, inputs, *soutputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
258
259 FuncGraphManagerPtr mng = fg->manager();
260 if (mng == nullptr) {
261 mng = Manage(fg, false);
262 fg->set_manager(mng);
263 }
264
265 // Inline origin graphkernel
266 auto cnodes = fg->GetOrderedCnodes();
267 for (const auto &n : cnodes) {
268 if (!AnfAlgo::IsGraphKernel(n)) {
269 continue;
270 }
271 auto graph_kernel_g = GetValueNode<FuncGraphPtr>(n->input(0));
272 AnfNodePtrList ins;
273 ins.insert(ins.end(), n->inputs().begin() + 1, n->inputs().end());
274 auto out = InlineClone(graph_kernel_g, fg, ins, n->input(0)->scope());
275 mng->Replace(n, out);
276 }
277
278 EliminateMakeTuple(fg, mng);
279 ConvertNonscalarTensorToParameter(fg, &inputs);
280
281 outputs.clear();
282 kernel::GetFuncGraphOutputNodes(fg, &outputs);
283 return std::make_tuple(fg, inputs, outputs);
284 }
285
286 // Rebuild as node inputs or outputs have changed, processor comes from node itself
BuildSelectKernelBuildInfo(const std::vector<std::string> & inputs_format,const std::vector<TypeId> & inputs_type,const std::vector<std::string> & output_formats,const std::vector<TypeId> & output_types,const AnfNodePtr & node)287 kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
288 const std::vector<TypeId> &inputs_type,
289 const std::vector<std::string> &output_formats,
290 const std::vector<TypeId> &output_types, const AnfNodePtr &node) {
291 kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
292 graph_info_builder.SetInputsFormat(inputs_format);
293 graph_info_builder.SetInputsDeviceType(inputs_type);
294 graph_info_builder.SetOutputsFormat(output_formats);
295 graph_info_builder.SetOutputsDeviceType(output_types);
296 graph_info_builder.SetProcessor(AnfAlgo::GetProcessor(node));
297 graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
298 graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
299 return graph_info_builder.Build();
300 }
301
302 // Build for new node, processor comes from context
BuildSelectKernelBuildInfo(const std::vector<std::string> & inputs_format,const std::vector<TypeId> & inputs_type,const std::vector<std::string> & output_formats,const std::vector<TypeId> & output_types)303 kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::string> &inputs_format,
304 const std::vector<TypeId> &inputs_type,
305 const std::vector<std::string> &output_formats,
306 const std::vector<TypeId> &output_types) {
307 kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
308 graph_info_builder.SetInputsFormat(inputs_format);
309 graph_info_builder.SetInputsDeviceType(inputs_type);
310 graph_info_builder.SetOutputsFormat(output_formats);
311 graph_info_builder.SetOutputsDeviceType(output_types);
312 graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
313 graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
314 graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
315 return graph_info_builder.Build();
316 }
317
SetNewKernelInfo(const AnfNodePtr & new_node,const FuncGraphPtr & fg,const AnfNodePtrList & inputs,const AnfNodePtrList & outputs)318 void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
319 const AnfNodePtrList &outputs) {
320 std::vector<std::string> graph_input_format;
321 std::vector<TypeId> graph_input_type;
322 std::vector<std::string> graph_output_format;
323 std::vector<TypeId> graph_output_type;
324 for (size_t i = 0; i < inputs.size(); ++i) {
325 auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
326 if (kernel_with_index.first->isa<ValueNode>()) {
327 auto tensor = GetValueNode<tensor::TensorPtr>(kernel_with_index.first);
328 MS_EXCEPTION_IF_NULL(tensor);
329 (void)graph_input_format.emplace_back(kOpFormat_DEFAULT);
330 (void)graph_input_type.emplace_back(tensor->data_type());
331 } else {
332 auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
333 (void)graph_input_format.emplace_back(std::move(input_format));
334 auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
335 (void)graph_input_type.emplace_back(input_type);
336 }
337 auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
338 fg->parameters()[i]->set_abstract(input_abs);
339 fg->parameters()[i]->set_kernel_info(std::make_shared<device::KernelInfo>());
340 kernel::KernelBuildInfo::KernelBuildInfoBuilder para_info_builder;
341 para_info_builder.SetOutputsFormat({graph_input_format.back()});
342 para_info_builder.SetOutputsDeviceType({graph_input_type.back()});
343 para_info_builder.SetKernelType(KernelType::AKG_KERNEL);
344 para_info_builder.SetProcessor(kernel::GetProcessorFromContext());
345 AnfAlgo::SetSelectKernelBuildInfo(para_info_builder.Build(), fg->parameters()[i].get());
346 }
347 auto new_outputs = outputs;
348 if (outputs.size() == 1 && AnfAlgo::IsGraphKernel(outputs[0])) {
349 std::vector<AnfNodePtr> real_outs;
350 if (IsMakeTupleOut(outputs[0], &real_outs)) {
351 new_outputs = real_outs;
352 }
353 }
354 for (size_t i = 0; i < new_outputs.size(); ++i) {
355 auto kernel_with_index = AnfAlgo::VisitKernel(new_outputs[i], 0);
356 auto output_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
357 auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
358 graph_output_format.push_back(output_format);
359 graph_output_type.push_back(output_type);
360 }
361 kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
362 graph_info_builder.SetInputsFormat(graph_input_format);
363 graph_info_builder.SetInputsDeviceType(graph_input_type);
364 graph_info_builder.SetOutputsFormat(graph_output_format);
365 graph_info_builder.SetOutputsDeviceType(graph_output_type);
366 graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
367 graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
368 graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
369 auto graph_selected_info = graph_info_builder.Build();
370 AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get());
371 }
372
CreateNewFuseCNode(const FuncGraphPtr & func_graph,const FuncGraphPtr & fg,const AnfNodePtrList & inputs,const AnfNodePtrList & outputs)373 AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &func_graph, const FuncGraphPtr &fg, const AnfNodePtrList &inputs,
374 const AnfNodePtrList &outputs) {
375 auto func_node = NewValueNode(fg);
376 std::vector<AnfNodePtr> fn_inputs;
377 fn_inputs.push_back(func_node);
378 fn_inputs.insert(fn_inputs.end(), inputs.begin(), inputs.end());
379 auto fuse_cnode = func_graph->NewCNode(fn_inputs);
380 // Set output abstract
381 if (outputs.size() > 1) {
382 std::vector<AbstractBasePtr> out_specs;
383 for (size_t i = 0; i < outputs.size(); ++i) {
384 out_specs.push_back(outputs[i]->abstract());
385 }
386 auto out_spec = std::make_shared<abstract::AbstractTuple>(out_specs);
387 fuse_cnode->set_abstract(out_spec);
388 } else {
389 fuse_cnode->set_abstract(outputs[0]->abstract());
390 }
391 // Set parameter abstract.
392 for (size_t i = 0; i < inputs.size(); ++i) {
393 auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
394 auto input_abs = GetOutputAbstract(kernel_with_index.first, kernel_with_index.second);
395 fg->parameters()[i]->set_abstract(input_abs);
396 }
397 return fuse_cnode;
398 }
399
ReplaceNewFuseCNode(const FuncGraphPtr & func_graph,const AnfNodePtr & new_fuse_cnode,const AnfNodePtrList & outputs)400 void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_fuse_cnode,
401 const AnfNodePtrList &outputs) {
402 MS_EXCEPTION_IF_NULL(func_graph);
403 auto mng = func_graph->manager();
404 MS_EXCEPTION_IF_NULL(mng);
405 // single out
406 if (outputs.size() == 1) {
407 mng->Replace(outputs[0], new_fuse_cnode);
408 return;
409 }
410
411 std::vector<AnfNodePtr> fn_inputs;
412 size_t offset = 0;
413 for (size_t out_idx = 0; out_idx < outputs.size(); out_idx++) {
414 AnfNodePtrList real_outs;
415 // not make tuple out, replace
416 if (!IsMakeTupleOut(outputs[out_idx], &real_outs)) {
417 fn_inputs.clear();
418 fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
419 fn_inputs.push_back(new_fuse_cnode);
420 fn_inputs.push_back(NewValueNode(MakeValue(SizeToLong(out_idx + offset))));
421 auto new_out = func_graph->NewCNode(fn_inputs);
422 new_out->set_abstract(outputs[out_idx]->abstract());
423 mng->Replace(outputs[out_idx], new_out);
424 continue;
425 }
426
427 // the out is make tuple , modify the get_item node's value
428 auto users = mng->node_users()[outputs[out_idx]];
429 for (auto &user : users) {
430 auto use_node = user.first;
431 if (!use_node->isa<CNode>() || !IsPrimitiveCNode(use_node, prim::kPrimTupleGetItem)) {
432 continue;
433 }
434 auto get_item_cnode = use_node->cast<CNodePtr>();
435 auto value_input = get_item_cnode->input(kInputNodeOutputIndexInTupleGetItem);
436 MS_EXCEPTION_IF_NULL(value_input);
437 auto value_node = value_input->cast<ValueNodePtr>();
438 MS_EXCEPTION_IF_NULL(value_node);
439 auto item_idx = GetValue<int64_t>(value_node->value());
440 int64_t new_item_idx = SizeToLong(out_idx + offset) + item_idx;
441 fn_inputs.clear();
442 fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
443 fn_inputs.push_back(new_fuse_cnode);
444 fn_inputs.push_back(NewValueNode(new_item_idx));
445 auto new_out = func_graph->NewCNode(fn_inputs);
446 new_out->set_abstract(get_item_cnode->abstract());
447 mng->Replace(get_item_cnode, new_out);
448 }
449
450 offset += real_outs.size() - 1;
451 }
452 }
453
FuseNodesToSubGraph(const std::vector<AnfNodePtr> & fuse_nodes,const FuncGraphPtr & kernel_graph,const std::string & postfix)454 std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
455 const FuncGraphPtr &kernel_graph,
456 const std::string &postfix) {
457 auto mng = kernel_graph->manager();
458 if (mng == nullptr) {
459 mng = Manage(kernel_graph, true);
460 kernel_graph->set_manager(mng);
461 }
462
463 FuncGraphPtr fg;
464 AnfNodePtrList inputs;
465 AnfNodePtrList src_outputs;
466 AnfNodePtrList outputs;
467
468 std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(fuse_nodes, &src_outputs);
469 auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
470 SetNewKernelInfo(fuse_new_node, fg, inputs, outputs);
471 // Handle get-item probleam.
472 ReplaceNewFuseCNode(kernel_graph, fuse_new_node, src_outputs);
473
474 // set graphKernel attr
475 std::string fuse_op_name = "";
476 for (auto &fuse_node : fuse_nodes) {
477 if (IsPrimitiveCNode(fuse_node)) {
478 fuse_op_name += AnfAlgo::GetCNodePrimitive(fuse_node)->name() + "_";
479 } else if (AnfAlgo::IsGraphKernel(fuse_node)) {
480 auto fuse_cnode = fuse_node->cast<CNodePtr>();
481 MS_EXCEPTION_IF_NULL(fuse_cnode);
482 auto graph_kernel_fg = GetValueNode<FuncGraphPtr>(fuse_cnode->input(kAnfPrimitiveIndex));
483 auto fg_flag_val = graph_kernel_fg->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
484 auto fuse_fg_name = GetValue<std::string>(fg_flag_val);
485 fuse_op_name += fuse_fg_name + "_";
486 }
487 }
488 fuse_op_name += postfix;
489 fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
490
491 return std::make_tuple(fuse_new_node, src_outputs);
492 }
493
AnfToJsonDesc(const AnfNodePtrList & nodes,const DumpOption & dump_option,nlohmann::json * op_desc,std::map<std::string,AnfNodePtr> * address_node_map)494 bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
495 std::map<std::string, AnfNodePtr> *address_node_map) {
496 MS_EXCEPTION_IF_NULL(op_desc);
497 if (nodes.empty()) {
498 MS_LOG(ERROR) << "Input nodes is empty.";
499 return false;
500 }
501 bool has_graph_kernel = std::any_of(nodes.begin(), nodes.end(), AnfAlgo::IsGraphKernel);
502 bool is_single_graph_kernel = has_graph_kernel && nodes.size() == 1;
503
504 FuncGraphPtr fg;
505 AnfNodePtrList op_nodes, inputs, outputs;
506 if (is_single_graph_kernel) {
507 fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
508 kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
509 } else if (!has_graph_kernel) {
510 std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(nodes);
511 op_nodes = nodes;
512 } else {
513 // When there are basic and composite ops, the composite ops should be inline to the basic ones' graph,
514 // so a new graph generation should be done (because they may in the main graph!).
515 // If address_node_map is wanted, we should map the new node in new graph to the old nodes. But... not support now.
516 MS_LOG(EXCEPTION) << "No support mixed with basic and composite ops now!";
517 }
518 std::pair<AnfNodePtrList, AnfNodePtrList> in_and_out = std::make_pair(inputs, outputs);
519 return GenJson(op_nodes, in_and_out, dump_option, op_desc, address_node_map);
520 }
521
AnfToJsonDesc(const AnfNodePtrList & nodes,const DumpOption & dump_option,nlohmann::json * op_desc)522 bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc) {
523 MS_EXCEPTION_IF_NULL(op_desc);
524 if (nodes.empty()) {
525 MS_LOG(ERROR) << "Input nodes is empty.";
526 return false;
527 }
528
529 FuncGraphPtr fg;
530 AnfNodePtrList op_nodes, inputs, outputs;
531 if (nodes.size() == 1 && AnfAlgo::IsGraphKernel(nodes[0])) {
532 fg = AnfAlgo::GetCNodeFuncGraphPtr(nodes[0]);
533 } else {
534 std::tie(fg, inputs, outputs) = MixedNodesTransToGraph(nodes);
535 inputs.clear();
536 outputs.clear();
537 }
538
539 kernel::GetValidKernelNodes(fg, &op_nodes, &inputs, &outputs);
540
541 auto mng = fg->manager();
542 if (mng == nullptr) {
543 mng = Manage(fg, false);
544 fg->set_manager(mng);
545 }
546 std::pair<AnfNodePtrList, AnfNodePtrList> in_and_out = std::make_pair(inputs, outputs);
547 return GenJson(op_nodes, in_and_out, dump_option, op_desc);
548 }
549
AnfToJsonDesc(const std::vector<AnfNodePtrList> & graphs,const DumpOption & dump_option,nlohmann::json * op_desc)550 bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc) {
551 MS_EXCEPTION_IF_NULL(op_desc);
552 std::vector<nlohmann::json> graphs_desc;
553 for (auto const &graph_nodes : graphs) {
554 nlohmann::json desc;
555 if (!AnfToJsonDesc(graph_nodes, dump_option, &desc)) {
556 MS_LOG(ERROR) << "Collect json desc failed.";
557 return false;
558 }
559 graphs_desc.push_back(desc);
560 }
561 if (graphs_desc.empty()) {
562 MS_LOG(ERROR) << "Collect zero json desc.";
563 return false;
564 }
565
566 if (graphs_desc.size() > 1) {
567 nlohmann::json op_json_desc;
568 op_json_desc[kJsonKeyMultiGraph] = true;
569 op_json_desc[kJsonKeyGraphDesc] = graphs_desc;
570 *op_desc = op_json_desc;
571 return true;
572 }
573
574 *op_desc = graphs_desc[0];
575 return true;
576 }
577
JsonDescToAnf(const std::string & json_desc)578 FuncGraphPtr JsonDescToAnf(const std::string &json_desc) {
579 kernel::AkgKernelJsonDecoder akg_kernel_json_decoder;
580 auto fg = akg_kernel_json_decoder.DecodeFusedNodes(json_desc);
581 if (fg == nullptr) {
582 MS_LOG(ERROR) << "Akg decode json to graph failed.";
583 return nullptr;
584 }
585 return fg;
586 }
587
ExtractGraphKernelName(const AnfNodePtrList & cnodes,const string & prefix,const string & postfix)588 std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix, const string &postfix) {
589 std::stringstream name;
590 if (prefix != "") {
591 name << prefix << "_";
592 }
593 for (const auto &node : cnodes) {
594 if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
595 name << AnfAlgo::GetCNodeName(node) << "_";
596 }
597 }
598 if (postfix != "") {
599 name << postfix;
600 }
601 return name.str();
602 }
603
ResetKernelInfo(const AnfNodePtr & node,KernelType kernel_type)604 void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
605 auto cnode = node->cast<CNodePtr>();
606 MS_EXCEPTION_IF_NULL(cnode);
607 #if ENABLE_D
608 device::ascend::SetKernelInfo(cnode, kernel_type);
609 #elif ENABLE_GPU
610 cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
611 device::gpu::SetKernelInfo(cnode, kernel_type);
612 #endif
613 }
614
GetFormat(const AnfNodePtr & node)615 std::string GetFormat(const AnfNodePtr &node) { return AnfAlgo::GetOutputFormat(node, 0); }
616
GetType(const AnfNodePtr & node)617 TypePtr GetType(const AnfNodePtr &node) {
618 const auto &abstract = node->abstract();
619 auto type = abstract->BuildType();
620 MS_EXCEPTION_IF_NULL(type);
621 return type;
622 }
623
GetShape(const AnfNodePtr & node)624 ShapeVector GetShape(const AnfNodePtr &node) {
625 auto abstract = node->abstract();
626 MS_EXCEPTION_IF_NULL(abstract);
627 auto shape = abstract->GetShapeTrack();
628 if (shape == nullptr || !shape->isa<abstract::Shape>()) {
629 MS_LOG(EXCEPTION) << "Cannot get shape from " << node->fullname_with_scope();
630 }
631 auto shape_vec = shape->cast<abstract::ShapePtr>()->shape();
632 if (shape_vec.empty()) {
633 shape_vec.push_back(1);
634 }
635 return shape_vec;
636 }
637
GetDeviceShape(const AnfNodePtr & node)638 ShapeVector GetDeviceShape(const AnfNodePtr &node) {
639 ShapeVector res_device_shape;
640 auto device_shape = AnfAlgo::GetOutputDeviceShape(node, 0);
641 if (device_shape.empty()) {
642 res_device_shape.push_back(1);
643 } else {
644 (void)std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(res_device_shape), SizeToLong);
645 }
646 return res_device_shape;
647 }
648
GetReduceAxis(const AnfNodePtr & node)649 std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) {
650 auto prim = GetCNodePrimitive(node);
651 MS_EXCEPTION_IF_NULL(prim);
652 const auto &attrs = prim->attrs();
653 auto iter = attrs.find("axis");
654 if (iter == attrs.end()) {
655 MS_LOG(EXCEPTION) << "Origin node have no attributes!";
656 }
657
658 std::vector<int64_t> axis;
659
660 auto &v = iter->second;
661 if (v->isa<ValueList>() || v->isa<ValueTuple>()) {
662 auto vec = v->isa<ValueList>() ? v->cast<ValueListPtr>()->value() : v->cast<ValueTuplePtr>()->value();
663 for (auto value : vec) {
664 if (value->isa<Int64Imm>()) {
665 axis.push_back(GetValue<int64_t>(value));
666 } else {
667 MS_LOG(EXCEPTION) << "Reduce axis type should be int64!";
668 }
669 }
670 } else if (v->isa<Int64Imm>()) {
671 axis.push_back(GetValue<int64_t>(v));
672 } else {
673 MS_LOG(EXCEPTION) << "Reduce axis should be a list or tuple!";
674 }
675
676 return axis;
677 }
678
CreateCNode(const std::vector<AnfNodePtr> & inputs,const FuncGraphPtr & func_graph,const DataInfo & out_info,bool use_fake_abstract)679 CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info,
680 bool use_fake_abstract) {
681 // Limitation: 1. Node's attributes should be set out of this function; 2. only one output.
682 MS_EXCEPTION_IF_NULL(out_info.type);
683 auto out_type = out_info.type;
684 if (auto otype = out_info.type->cast<TensorTypePtr>(); otype != nullptr) {
685 out_type = otype->element();
686 }
687
688 // Create CNode.
689 auto cnode = func_graph->NewCNode(inputs);
690 MS_EXCEPTION_IF_NULL(cnode);
691
692 // Setup abstract.
693 if (use_fake_abstract) {
694 auto abs_shape = kernel::GetFakeAbstractShape(out_info.shape, out_info.format);
695 auto abs_tensor = std::make_shared<abstract::AbstractTensor>(out_type, abs_shape);
696 cnode->set_abstract(abs_tensor);
697 } else {
698 auto abs_tensor = std::make_shared<abstract::AbstractTensor>(out_type, out_info.shape);
699 cnode->set_abstract(abs_tensor);
700 }
701
702 // Setup kernel info.
703 auto kernel_info = std::make_shared<device::KernelInfo>();
704 cnode->set_kernel_info(kernel_info);
705 std::vector<size_t> feature_map_input_indexs;
706 kernel_info->set_feature_map_flag(false);
707 for (size_t i = 1; i < inputs.size(); ++i) {
708 if (AnfAlgo::IsFeatureMapOutput(inputs[i])) {
709 kernel_info->set_feature_map_flag(true);
710 feature_map_input_indexs.push_back(i);
711 }
712 }
713 if (inputs.size() == 1) {
714 kernel_info->set_feature_map_flag(true);
715 }
716 if (AnfAlgo::IsRealKernel(cnode)) {
717 // if the node only has the primitive(such as getNext) or the node's input has a feature map input
718 // then the node's output is a feature map output
719 SetNodeAttrSafely(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
720 SetNodeAttrSafely(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
721 }
722
723 // Setup kernel build info.
724 std::vector<std::string> input_formats;
725 std::vector<TypeId> input_types;
726 for (size_t i = 1; i < inputs.size(); ++i) {
727 auto kernel_with_index = AnfAlgo::VisitKernel(inputs[i], 0);
728 auto input_format = AnfAlgo::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
729 input_formats.push_back(input_format);
730 auto input_type = AnfAlgo::GetOutputDeviceDataType(kernel_with_index.first, kernel_with_index.second);
731 input_types.push_back(input_type);
732 }
733
734 std::vector<std::string> output_formats = {out_info.format};
735 std::vector<TypeId> output_types = {out_type->type_id()};
736
737 kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
738 info_builder.SetInputsFormat(input_formats);
739 info_builder.SetInputsDeviceType(input_types);
740 info_builder.SetOutputsFormat(output_formats);
741 info_builder.SetOutputsDeviceType(output_types);
742 info_builder.SetProcessor(kernel::GetProcessorFromContext());
743 info_builder.SetKernelType(KernelType::AKG_KERNEL);
744 info_builder.SetFusionType(kernel::FusionType::OPAQUE);
745 auto selected_info = info_builder.Build();
746 AnfAlgo::SetSelectKernelBuildInfo(selected_info, cnode.get());
747
748 func_graph->AddNode(cnode);
749 return cnode;
750 }
751
SetNodeAttrSafely(const std::string & key,const ValuePtr & value,const AnfNodePtr & node)752 void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
753 // Make CNode safe to set attr firstly.
754 auto cnode = node->cast<CNodePtr>();
755 if (cnode == nullptr) {
756 return;
757 }
758 AnfNodePtrList new_inputs = {NewValueNode(AnfAlgo::GetCNodePrimitive(cnode)->Clone())};
759 auto inputs = cnode->inputs();
760 new_inputs.insert(new_inputs.end(), inputs.begin() + 1, inputs.end());
761 cnode->set_inputs(new_inputs);
762
763 // Set attr secondly.
764 AnfAlgo::SetNodeAttr(key, value, node);
765 }
766
IsKeepBasicNode(const AnfNodePtr & node)767 bool IsKeepBasicNode(const AnfNodePtr &node) {
768 MS_EXCEPTION_IF_NULL(node);
769 if (!node->isa<CNode>()) {
770 return false;
771 }
772 auto cnode = node->cast<CNodePtr>();
773 MS_EXCEPTION_IF_NULL(cnode);
774
775 // Dynamic shape is unsupported yet.
776 if (AnfAlgo::HasDynamicShapeFlag(AnfAlgo::GetCNodePrimitive(cnode))) {
777 return true;
778 }
779
780 static const std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
781 "aggregate", "aggregate_input_indexx"};
782 // If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is.
783 if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
784 [&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); })) {
785 return true;
786 }
787 if (AnfAlgo::GetBooleanAttr(cnode, "skip")) {
788 return true;
789 }
790 return false;
791 }
792
OpListFilter(std::vector<PrimitivePtr> * ops,const std::vector<std::string> & enable_ops_only,const std::vector<std::string> & enable_ops,const std::vector<std::string> & disable_ops)793 void OpListFilter(std::vector<PrimitivePtr> *ops, const std::vector<std::string> &enable_ops_only,
794 const std::vector<std::string> &enable_ops, const std::vector<std::string> &disable_ops) {
795 auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
796 if (!enable_ops_only.empty()) {
797 ops->clear();
798 (void)std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::back_inserter(*ops), new_prim);
799 } else {
800 if (!enable_ops.empty()) {
801 (void)std::transform(enable_ops.begin(), enable_ops.end(), std::back_inserter(*ops), new_prim);
802 }
803 if (!disable_ops.empty()) {
804 auto iter = std::remove_if(ops->begin(), ops->end(), [&disable_ops](const PrimitivePtr &p) {
805 return std::find(disable_ops.begin(), disable_ops.end(), p->name()) != disable_ops.end();
806 });
807 (void)ops->erase(iter, ops->end());
808 }
809 }
810 }
811
// Convert a graph-kernel sub FuncGraph into the equivalent LiteGraph:
// parameters become LiteGraph inputs, each CNode becomes a LiteGraph op,
// and the graph output (or the inputs of a trailing MakeTuple) becomes
// the LiteGraph outputs.
graphkernel::LiteGraphPtr AnfGraph2LiteGraph(const FuncGraphPtr &func_graph) {
  // The builder is named after the graph kernel's FUNC_GRAPH_ATTR_GRAPH_KERNEL attr.
  graphkernel::LiteGraph::GraphBuilder gb(GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)));
  std::map<AnfNodePtr, graphkernel::NodePtr> node_map;  // AnfNode -> LiteGraph node
  auto todos = TopoSort(func_graph->output());
  const auto &params = func_graph->parameters();
  // Collect the device shape/type/format of a node's output 0 as LiteGraph build info.
  auto ExtractBuildInfo = [](const AnfNodePtr &node) {
    auto shape = GetDeviceShape(node);
    auto type = AnfAlgo::GetOutputDeviceDataType(node, 0);
    auto format = AnfAlgo::GetOutputFormat(node, 0);
    return graphkernel::NodeBase({shape, type, format});
  };
  // set inputs
  for (size_t i = 0; i < params.size(); i++) {
    node_map[params[i]] = gb.Parameter(ExtractBuildInfo(params[i]), std::string("input_") + std::to_string(i));
  }
  // set ops
  for (auto node : todos) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) continue;
    // MakeTuple only aggregates outputs; it and anything after it in the
    // topo order are not real ops, so stop here.
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) break;
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    graphkernel::NodePtrList inputs;
    // Map each real input to its LiteGraph counterpart; inputs not yet in
    // node_map must be const-tensor value nodes, materialized via gb.Value.
    (void)std::transform(cnode->inputs().begin() + 1, cnode->inputs().end(), std::back_inserter(inputs),
                         [&node_map, &gb](const AnfNodePtr &no) {
                           auto iter = node_map.find(no);
                           if (iter != node_map.end()) {
                             return iter->second;
                           } else {
                             auto tensor = GetValueNode<tensor::TensorPtr>(no);
                             MS_EXCEPTION_IF_NULL(tensor);
                             return gb.Value(tensor);
                           }
                         });
    node_map[node] = gb.Op(AnfAlgo::GetCNodeName(node), ExtractBuildInfo(node), inputs, prim->attrs());
  }
  // set outputs
  auto output_node = func_graph->output();
  if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
    // Multi-output graph: unpack the MakeTuple inputs as LiteGraph outputs.
    graphkernel::NodePtrList outputs;
    auto mt = output_node->cast<CNodePtr>();
    (void)std::transform(mt->inputs().begin() + 1, mt->inputs().end(), std::back_inserter(outputs),
                         [&node_map](const AnfNodePtr &no) { return node_map[no]; });
    gb.SetOutputs(std::move(outputs));
  } else {
    gb.SetOutputs({node_map[output_node]});
  }
  return gb.Get();
}
861
LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr & lite_graph,AnfNodePtrList * outputs)862 FuncGraphPtr LiteGraph2AnfGraph(const graphkernel::LiteGraphPtr &lite_graph, AnfNodePtrList *outputs) {
863 auto func_graph = std::make_shared<FuncGraph>();
864 std::map<graphkernel::NodePtr, AnfNodePtr> node_map;
865 for (const auto &inp : lite_graph->inputs()) {
866 auto param = func_graph->add_parameter();
867 node_map[inp] = param;
868 auto abs_shape = kernel::GetFakeAbstractShape(inp->shape, inp->format);
869 param->set_abstract(std::make_shared<abstract::AbstractTensor>(TypeIdToType(inp->type), abs_shape));
870 param->set_kernel_info(std::make_shared<device::KernelInfo>());
871 auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
872 AnfAlgo::SetSelectKernelBuildInfo(build_info, param.get());
873 }
874 // Create CNodes.
875 for (const auto &op_node : lite_graph->GetOrderedNodes()) {
876 if (op_node->NodeType() != graphkernel::NType::Primitive) {
877 MS_LOG(EXCEPTION) << "Node " << op_node->name() << "should be a Primitive node";
878 }
879 auto op = std::static_pointer_cast<graphkernel::PrimOp>(op_node);
880 AnfNodePtrList inputs = {NewValueNode(std::make_shared<Primitive>(op->op(), op->attrs()))};
881 (void)std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(inputs),
882 [&node_map](const graphkernel::NodePtr &inp) -> AnfNodePtr {
883 auto iter = node_map.find(inp);
884 if (iter != node_map.end()) {
885 return iter->second;
886 } else {
887 if (inp->NodeType() != graphkernel::NType::Value) {
888 MS_LOG(EXCEPTION) << "Node " << inp->name() << "should be a Value node";
889 }
890 auto inp_value = inp->As<graphkernel::ConstTensorNode>()->data();
891 auto value_node = NewValueNode(inp_value);
892 value_node->set_abstract(inp_value->ToAbstract());
893 value_node->set_kernel_info(std::make_shared<device::KernelInfo>());
894 auto build_info = BuildSelectKernelBuildInfo({}, {}, {inp->format}, {inp->type});
895 AnfAlgo::SetSelectKernelBuildInfo(build_info, value_node.get());
896 return value_node;
897 }
898 });
899 auto cnode = CreateCNode(inputs, func_graph, {op->format, op->shape, TypeIdToType(op->type)}, true);
900 MS_EXCEPTION_IF_NULL(cnode);
901 node_map[op_node] = cnode;
902 }
903 if (lite_graph->GetOutputs().empty()) {
904 MS_LOG(EXCEPTION) << "The output of LiteGraph " << lite_graph->name() << " is empty.";
905 } else if (lite_graph->GetOutputs().size() == 1) {
906 func_graph->set_output(node_map[lite_graph->GetOutputs()[0]]);
907 if (outputs != nullptr) {
908 (void)outputs->emplace_back(func_graph->output());
909 }
910 } else {
911 AnfNodePtrList mt_inputs;
912 AbstractBasePtrList out_abs_list;
913 (void)std::transform(lite_graph->GetOutputs().begin(), lite_graph->GetOutputs().end(),
914 std::back_inserter(mt_inputs), [&node_map, &out_abs_list](const graphkernel::NodePtr &out) {
915 auto out_node = node_map[out];
916 MS_EXCEPTION_IF_NULL(out_node);
917 (void)out_abs_list.emplace_back(out_node->abstract());
918 return out_node;
919 });
920 auto mt = func_graph->NewCNode(prim::kPrimMakeTuple, mt_inputs);
921 mt->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
922 mt->set_kernel_info(std::make_shared<device::KernelInfo>());
923 func_graph->AddNode(mt);
924 func_graph->set_output(mt);
925 if (outputs != nullptr) {
926 *outputs = std::move(mt_inputs);
927 }
928 }
929 return func_graph;
930 }
931
EliminateRedundantParameters(const FuncGraphPtr & func_graph,AnfNodePtrList * inputs)932 void EliminateRedundantParameters(const FuncGraphPtr &func_graph, AnfNodePtrList *inputs) {
933 const auto &ori_parameter = func_graph->parameters();
934 auto todos = TopoSort(func_graph->get_return());
935 std::set<AnfNodePtr> used_param;
936 for (auto node : todos) {
937 if (node->isa<Parameter>()) {
938 (void)used_param.insert(node);
939 }
940 }
941 if (used_param.size() == ori_parameter.size()) {
942 return;
943 }
944 AnfNodePtrList new_parameter, new_inputs;
945 for (size_t i = 0; i < ori_parameter.size(); ++i) {
946 if (used_param.count(ori_parameter[i])) {
947 new_parameter.push_back(ori_parameter[i]);
948 new_inputs.push_back((*inputs)[i]);
949 }
950 }
951 func_graph->set_parameters(new_parameter);
952 *inputs = std::move(new_inputs);
953 }
954
GetValidOps(const std::vector<std::tuple<std::string,unsigned int,PrimitivePtr>> & ops_with_level,unsigned int level)955 std::vector<PrimitivePtr> GetValidOps(
956 const std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> &ops_with_level, unsigned int level) {
957 auto context_ptr = MsContext::GetInstance();
958 MS_EXCEPTION_IF_NULL(context_ptr);
959 std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
960 std::vector<PrimitivePtr> valid_ops;
961 for (const auto &[op_target, op_level, op] : ops_with_level) {
962 if (op_target == kAllTarget || op_target == target) {
963 if (level >= op_level) {
964 (void)valid_ops.emplace_back(op);
965 }
966 }
967 }
968 return valid_ops;
969 }
970
GetFuncGraphManager(const FuncGraphPtr & func_graph)971 FuncGraphManagerPtr GetFuncGraphManager(const FuncGraphPtr &func_graph) {
972 MS_EXCEPTION_IF_NULL(func_graph);
973 FuncGraphManagerPtr manager = func_graph->manager();
974 if (manager == nullptr) {
975 manager = Manage(func_graph, true);
976 func_graph->set_manager(manager);
977 }
978 return manager;
979 }
980
// Rebuild the manager's root set so it tracks exactly `func_graph`:
// drop all current roots first, then re-register func_graph as the only root.
// NOTE(review): order matters — KeepRoots after RemoveRoots re-adds the graph
// and everything reachable from it.
void UpdateMng(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph) {
  mng->RemoveRoots();
  mng->KeepRoots({func_graph});
}
985 } // namespace opt
986 } // namespace mindspore
987