• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/graph_kernel/core/graph_kernel_expander.h"
18 
19 #include "utils/anf_utils.h"
20 #include "backend/common/graph_kernel/core/graph_builder.h"
21 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
22 #include "include/common/utils/anfalgo.h"
23 
24 namespace mindspore::graphkernel {
CreateExpandedNode(const CNodePtr & node,const std::string & name) const25 AnfNodePtr GraphKernelExpander::CreateExpandedNode(const CNodePtr &node, const std::string &name) const {
26   auto new_fg = GetCNodeFuncGraph(node);
27   new_fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(name));
28   auto main_graph = node->func_graph();
29   std::vector<AnfNodePtr> inputs(node->inputs().begin() + 1, node->inputs().end());
30   (void)ConvertTensorToParameter(new_fg, &inputs);
31   auto graph_kernel_node = CreateNewFuseCNode(main_graph, new_fg, inputs);
32   MS_LOG(DEBUG) << "Expand node: " << node->fullname_with_scope()
33                 << " with: " << graph_kernel_node->fullname_with_scope();
34   return graph_kernel_node;
35 }
36 
// Primitives whose expanded graph exposes a different number of outputs than the original node.
// Each entry maps the primitive name to the output index (or indices) of the expanded node that
// stand in for the original node's result. ReplaceNodeWithTupleGetItem selects index [0] via TupleGetItem.
static const std::map<std::string, std::vector<size_t>> ops{
  {"ApplyMomentum", {1}},
};
38 
IsOuputNumInconsistent(const AnfNodePtr & node)39 bool IsOuputNumInconsistent(const AnfNodePtr &node) {
40   auto prim_name = GetCNodePrimitive(node)->name();
41   if (ops.find(prim_name) != ops.end()) {
42     return true;
43   }
44   return false;
45 }
46 
ReplaceNodeWithTupleGetItem(const AnfNodePtr & node,const AnfNodePtr & newnode,const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & mng)47 void ReplaceNodeWithTupleGetItem(const AnfNodePtr &node, const AnfNodePtr &newnode, const FuncGraphPtr &func_graph,
48                                  const FuncGraphManagerPtr &mng) {
49   const auto &output_indices = ops.at(GetCNodePrimitive(node)->name());
50   if (output_indices.size() == 1) {
51     auto idx = MakeValue(SizeToLong(output_indices[0]));
52     AnfNodePtrList inputs{NewValueNode(prim::kPrimTupleGetItem), newnode, NewValueNode(idx)};
53     inputs.back()->set_abstract(idx->ToAbstract());
54     auto new_out = func_graph->NewCNode(inputs);
55     auto abs = newnode->abstract();
56     if (!abs->isa<abstract::AbstractSequence>()) {
57       MS_LOG(EXCEPTION) << "The output abstract has to be an abstract sequence";
58     }
59     auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
60     auto elements = abs_seq->elements();
61     new_out->set_abstract(elements[output_indices[0]]);
62     mng->Replace(node, new_out);
63   } else {
64     MS_LOG(EXCEPTION) << "Unsupported at present";
65   }
66 }
67 
DoExpand(const FuncGraphPtr & func_graph)68 bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
69   bool changed = false;
70   auto todos = TopoSort(func_graph->output());
71   auto mng = func_graph->manager();
72   MS_EXCEPTION_IF_NULL(mng);
73   for (const auto &n : todos) {
74     auto node = n->cast<CNodePtr>();
75     if (node != nullptr) {
76       PreProcessAllNode(node);
77     }
78     if (node == nullptr || AnfUtils::IsGraphKernel(node) || GkUtils::IsKeepBasicNode(node) ||
79         !AnfUtils::IsRealKernel(node) || !CanExpand(node)) {
80       continue;
81     }
82     MS_LOG(DEBUG) << "Expanding node run start: " << node->fullname_with_scope();
83     auto newnode = InitExpander(node)->Run(node);
84     MS_LOG(DEBUG) << "Expanding node run end: " << node->fullname_with_scope();
85     if (newnode == nullptr) {
86       MS_LOG(DEBUG) << "Skipped node: " << node->fullname_with_scope();
87       continue;
88     }
89     if (newnode->isa<CNode>()) {
90       newnode = CreateExpandedNode(newnode->cast<CNodePtr>(), AnfUtils::GetCNodeName(node));
91     }
92     if (newnode == nullptr) {
93       MS_LOG(DEBUG) << "Skipped node: " << node->fullname_with_scope();
94       continue;
95     }
96     // For some ops, the output number of expander is different from the original cnode. In this case, a TupleGetItem is
97     // needed to insure that later cnodes have correct input
98     if (IsOuputNumInconsistent(node)) {
99       ReplaceNodeWithTupleGetItem(node, newnode, func_graph, mng);
100     } else {
101       mng->Replace(node, newnode);
102     }
103     changed = true;
104   }
105   return changed;
106 }
107 
Run(const FuncGraphPtr & func_graph)108 bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
109   expand_ops_ = InitOpList();
110   return DoExpand(func_graph);
111 }
112 }  // namespace mindspore::graphkernel
113