1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/graph_kernel/core/graph_kernel_expander.h"
18
19 #include "utils/anf_utils.h"
20 #include "backend/common/graph_kernel/core/graph_builder.h"
21 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
22 #include "include/common/utils/anfalgo.h"
23
24 namespace mindspore::graphkernel {
CreateExpandedNode(const CNodePtr & node,const std::string & name) const25 AnfNodePtr GraphKernelExpander::CreateExpandedNode(const CNodePtr &node, const std::string &name) const {
26 auto new_fg = GetCNodeFuncGraph(node);
27 new_fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(name));
28 auto main_graph = node->func_graph();
29 std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end());
30 (void)ConvertTensorToParameter(new_fg, &inputs);
31 auto graph_kernel_node = CreateNewFuseCNode(main_graph, new_fg, inputs);
32 MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope()
33 << " with: " << graph_kernel_node->fullname_with_scope();
34 return graph_kernel_node;
35 }
36
// Ops whose expanded graph exposes a different set of outputs than the original node.
// Key: primitive name. Value: output index(es) of the expanded node that stand in for
// the original node's result. ReplaceNodeWithTupleGetItem currently supports exactly
// one index per entry.
static const std::map<std::string, std::vector<size_t>> ops = {
  // ApplyMomentum is replaced by output #1 of its expansion.
  {"ApplyMomentum", {1}},
};
38
IsOuputNumInconsistent(const AnfNodePtr & node)39 bool IsOuputNumInconsistent(const AnfNodePtr &node) {
40 auto prim_name = GetCNodePrimitive(node)->name();
41 if (ops.find(prim_name) != ops.end()) {
42 return true;
43 }
44 return false;
45 }
46
ReplaceNodeWithTupleGetItem(const AnfNodePtr & node,const AnfNodePtr & newnode,const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & mng)47 void ReplaceNodeWithTupleGetItem(const AnfNodePtr &node, const AnfNodePtr &newnode, const FuncGraphPtr &func_graph,
48 const FuncGraphManagerPtr &mng) {
49 const auto &output_indices = ops.at(GetCNodePrimitive(node)->name());
50 if (output_indices.size() == 1) {
51 auto idx = MakeValue(SizeToLong(output_indices[0]));
52 AnfNodePtrList inputs{NewValueNode(prim::kPrimTupleGetItem), newnode, NewValueNode(idx)};
53 inputs.back()->set_abstract(idx->ToAbstract());
54 auto new_out = func_graph->NewCNode(inputs);
55 auto abs = newnode->abstract();
56 if (!abs->isa<abstract::AbstractSequence>()) {
57 MS_LOG(EXCEPTION) << "The output abstract has to be an abstract sequence";
58 }
59 auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
60 auto elements = abs_seq->elements();
61 new_out->set_abstract(elements[output_indices[0]]);
62 mng->Replace(node, new_out);
63 } else {
64 MS_LOG(EXCEPTION) << "Unsupported at present";
65 }
66 }
67
DoExpand(const FuncGraphPtr & func_graph)68 bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
69 bool changed = false;
70 auto todos = TopoSort(func_graph->output());
71 auto mng = func_graph->manager();
72 MS_EXCEPTION_IF_NULL(mng);
73 for (const auto &n : todos) {
74 auto node = n->cast<CNodePtr>();
75 if (node != nullptr) {
76 PreProcessAllNode(node);
77 }
78 if (node == nullptr || AnfUtils::IsGraphKernel(node) || GkUtils::IsKeepBasicNode(node) ||
79 !AnfUtils::IsRealKernel(node) || !CanExpand(node)) {
80 continue;
81 }
82 MS_LOG(DEBUG) << "Expanding node run start: " << node->fullname_with_scope();
83 auto newnode = InitExpander(node)->Run(node);
84 MS_LOG(DEBUG) << "Expanding node run end: " << node->fullname_with_scope();
85 if (newnode == nullptr) {
86 MS_LOG(DEBUG) << "Skipped node: " << node->fullname_with_scope();
87 continue;
88 }
89 if (newnode->isa<CNode>()) {
90 newnode = CreateExpandedNode(newnode->cast<CNodePtr>(), AnfUtils::GetCNodeName(node));
91 }
92 if (newnode == nullptr) {
93 MS_LOG(DEBUG) << "Skipped node: " << node->fullname_with_scope();
94 continue;
95 }
96 // For some ops, the output number of expander is different from the original cnode. In this case, a TupleGetItem is
97 // needed to insure that later cnodes have correct input
98 if (IsOuputNumInconsistent(node)) {
99 ReplaceNodeWithTupleGetItem(node, newnode, func_graph, mng);
100 } else {
101 mng->Replace(node, newnode);
102 }
103 changed = true;
104 }
105 return changed;
106 }
107
Run(const FuncGraphPtr & func_graph)108 bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
109 expand_ops_ = InitOpList();
110 return DoExpand(func_graph);
111 }
112 } // namespace mindspore::graphkernel
113