/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"

#include <string>
#include <set>
#include <utility>
#include <vector>
#include <tuple>
#include <algorithm>

#include "utils/context/graph_kernel_flags.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/split_umonad.h"
#include "backend/optimizer/graph_kernel/substitute_dropout.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "mindspore/core/ir/graph_utils.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pybind_api/ir/primitive_py.h"
#include "runtime/device/kernel_info.h"
#include "vm/segment_runner.h"

namespace mindspore {
namespace opt {
namespace {
using context::OpLevel_0;
using context::OpLevel_1;
constexpr size_t kAssignInputIdx = 1;
constexpr size_t kLambOptimizerInputIdx = 12;
constexpr size_t kLambWeightInputIdx = 4;
constexpr size_t kRandomInputIdx = 1;

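// Returns the list of operators that can be expanded on the current backend.
// Each entry below is {target device, op level, primitive}; entries whose level
// exceeds the configured fusion_ops_level are filtered out, and the result can
// be further tuned through the enable_expand_ops/disable_expand_ops flags.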
std::vector<PrimitivePtr> GetExpandOps() {
  std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> expand_ops_with_level = {
    {kAllTarget, OpLevel_0, prim::kPrimAddN},
    {kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
    {kAllTarget, OpLevel_0, prim::kPrimErfc},
    {kAllTarget, OpLevel_1, prim::kPrimExpandDims},
    {kAllTarget, OpLevel_0, prim::kPrimGeLU},
    {kAllTarget, OpLevel_0, prim::kPrimGeLUGrad},
    {kAllTarget, OpLevel_0, prim::kPrimSquare},
    {kAllTarget, OpLevel_0, prim::kPrimTile},
    {kAscendDevice, OpLevel_0, prim::kLambApplyOptimizerAssign},
    {kAscendDevice, OpLevel_0, prim::kLambApplyWeightAssign},
    {kAscendDevice, OpLevel_0, prim::kPrimClipByNormNoDivSum},
    {kAscendDevice, OpLevel_0, prim::kPrimSqrtGrad},
    {kAscendDevice, OpLevel_1, prim::kSoftmaxGradExt},
    {kAscendDevice, OpLevel_0, prim::kFusedMulAdd},
    {kGPUDevice, OpLevel_1, prim::kPrimBatchMatMul},
    {kGPUDevice, OpLevel_0, prim::kPrimBiasAdd},
    {kGPUDevice, OpLevel_1, prim::kPrimBiasAddGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimDropout},
    {kGPUDevice, OpLevel_0, prim::kPrimDropoutGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimFusedAdam},
    {kGPUDevice, OpLevel_0, prim::kPrimFusedAdamWeightDecay},
    {kGPUDevice, OpLevel_1, prim::kPrimMaximumGrad},
    {kGPUDevice, OpLevel_1, prim::kPrimMinimumGrad},
    {kGPUDevice, OpLevel_1, prim::kPrimLayerNorm},
    {kGPUDevice, OpLevel_1, prim::kPrimLayerNormGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimLogSoftmax},
    {kGPUDevice, OpLevel_0, prim::kPrimLogSoftmaxGrad},
    {kGPUDevice, OpLevel_1, prim::kPrimMatMul},
    {kGPUDevice, OpLevel_1, prim::kPrimReduceMean},
    {kGPUDevice, OpLevel_0, prim::kPrimRelu},
    {kGPUDevice, OpLevel_0, prim::kPrimReluGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimSigmoid},
    {kGPUDevice, OpLevel_0, prim::kPrimSigmoidGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogits},
    {kGPUDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogitsGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimSlice},
    {kGPUDevice, OpLevel_1, prim::kPrimSoftmax},
    {kGPUDevice, OpLevel_1, prim::kPrimSoftmaxCrossEntropyWithLogits},
    {kGPUDevice, OpLevel_0, prim::kPrimSquaredDifference},
    {kGPUDevice, OpLevel_0, prim::kPrimSqueeze},
    {kGPUDevice, OpLevel_0, prim::kPrimEqualCount},
    {kGPUDevice, OpLevel_0, prim::kPrimSquareSumAll},
    {kGPUDevice, OpLevel_0, prim::kPrimIdentityMath},
    {kGPUDevice, OpLevel_0, prim::kPrimOnesLike},
    {kGPUDevice, OpLevel_0, prim::kPrimStandardNormal},
  };
  const auto &flags = context::GraphKernelFlags::GetInstance();
  std::vector<PrimitivePtr> expand_ops = GetValidOps(expand_ops_with_level, flags.fusion_ops_level);
  OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops);
  return expand_ops;
}
}  // namespace

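// Collect the kernel json of the node, extracting the op info from the anf node itself.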
bool PyExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
  DumpOption dump_option;
  dump_option.extract_opinfo_from_anfnode = true;
  kernel::AkgKernelJsonGenerator json_generator(dump_option);
  return json_generator.CollectJson(node, kernel_json);
}

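// Expand the node through the Python op expander: dump the node to json, pass
// the json to the Python function kGetGraphKernelOpExpander, then decode the
// returned description back into a FuncGraph.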
FuncGraphPtr PyExpander::CreateExpandFuncGraph(const CNodePtr &node) {
  nlohmann::json kernel_json;
  if (!ExpandJsonInfo(node, &kernel_json)) {
    MS_LOG(ERROR) << "Expand json info of node: " << node->DebugString(2) << " failed, ori_json:\n"
                  << kernel_json.dump();
    return nullptr;
  }
  auto node_desc_str = kernel_json.dump();

  // Call the graph kernel op generator in Python.
  MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str;
  auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str);
  // Parse the result.
  if (py::isinstance<py::none>(ret)) {
    MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] return invalid result, input json:\n"
                  << node_desc_str;
    return nullptr;
  }
  std::string kernel_desc_str = py::cast<std::string>(ret);
  if (kernel_desc_str.empty()) {
    return nullptr;
  }
  // Decode the json description to a func_graph.
  return JsonDescToAnf(kernel_desc_str);
}

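// Prefer the C++ op expander registered in OpExpanderFactory; if no C++
// expander exists for this op, fall back to the Python expander.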
FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
  auto expander_ptr = expanders::OpExpanderFactory::Instance().GetExpander(AnfAlgo::GetCNodeName(node));
  if (expander_ptr == nullptr) {
    return PyExpander::CreateExpandFuncGraph(node);
  }
  expanders::BaseInfoList inputs(node->size() - 1);
  expanders::BaseInfoList outputs(AnfAlgo::GetOutputTensorNum(node));
  for (size_t i = 0; i < inputs.size(); i++) {
    auto shape = AnfAlgo::GetInputDeviceShape(node, i);
    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(inputs[i].shape), SizeToLong);
    inputs[i].type = AnfAlgo::GetInputDeviceDataType(node, i);
    inputs[i].format = AnfAlgo::GetInputFormat(node, i);
  }
  for (size_t i = 0; i < outputs.size(); i++) {
    auto shape = AnfAlgo::GetOutputDeviceShape(node, i);
    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(outputs[i].shape), SizeToLong);
    outputs[i].type = AnfAlgo::GetOutputDeviceDataType(node, i);
    outputs[i].format = AnfAlgo::GetOutputFormat(node, i);
  }
  auto &attrs = AnfAlgo::GetCNodePrimitive(node)->attrs();
  auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
  if (litegraph == nullptr) {
    MS_LOG(INFO) << "Skip expanding node: " << node->fullname_with_scope();
    return nullptr;
  }
  return LiteGraph2AnfGraph(litegraph);
}

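// Wrap the expanded func_graph into a new fused GraphKernel node that takes
// the old node's inputs, after eliminating redundant parameters.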
AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
  auto func_graph = old_node->func_graph();
  std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
  AnfNodePtrList kernel_nodes;
  AnfNodePtrList outputs;
  EliminateRedundantParameters(new_func_graph, &inputs);
  kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
  kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
  auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
  SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs);
  MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
                << " with: " << graph_kernel_node->fullname_with_scope();
  return graph_kernel_node;
}

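// Expand the node into a GraphKernel subgraph. Returns nullptr (leaving the
// original node unchanged) if expansion fails or the output count mismatches.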
AnfNodePtr PyExpander::Run(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto new_func_graph = CreateExpandFuncGraph(cnode);
  if (new_func_graph == nullptr) {
    return nullptr;
  }
  new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(cnode)));
  auto graph_kernel_node = CreateExpandGraphKernel(new_func_graph, cnode);
  if (AnfAlgo::GetOutputTensorNum(node) != AnfAlgo::GetOutputTensorNum(graph_kernel_node)) {
    MS_LOG(ERROR) << "The output num of composite node (" << AnfAlgo::GetOutputTensorNum(graph_kernel_node)
                  << ") does not match the original basic node (" << AnfAlgo::GetOutputTensorNum(node) << "). "
                  << node->fullname_with_scope();
    return nullptr;
  }
  return graph_kernel_node;
}

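// Ops with UMonad inputs (side effects) or special lowering (Dropout) need a
// dedicated expander; every other op goes through the DefaultExpander.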
ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
    {prim::kPrimAssignAdd, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
    {prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
    {prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambOptimizerInputIdx)},
    {prim::kLambApplyWeightAssign, std::make_shared<OpUMonadExpander>(kLambWeightInputIdx)},
    {prim::kPrimStandardNormal, std::make_shared<OpUMonadExpander>(kRandomInputIdx)},
  };

  for (auto &e : expanders) {
    if (IsPrimitiveCNode(node, e.first)) {
      return e.second;
    }
  }
  return std::make_shared<DefaultExpander>();
}

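// Traverse the graph in reversed topological order and try to expand every
// expandable basic node into a GraphKernel subgraph.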
bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  std::reverse(todos.begin(), todos.end());
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (const auto &n : todos) {
    auto node = n->cast<CNodePtr>();
    if (node == nullptr || AnfAlgo::IsGraphKernel(node) || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) ||
        !CanExpand(node)) {
      continue;
    }

    MS_LOG(DEBUG) << "Expanding node: " << node->fullname_with_scope();
    auto new_node = GetExpander(node)->Run(node);
    if (new_node == nullptr) {
      MS_LOG(DEBUG) << "Skipped node: " << node->fullname_with_scope();
      continue;
    }
    (void)mng->Replace(node, new_node);
    changed = true;
  }
  return changed;
}

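// A node is expanded by the complex expander only if at least one of its
// inputs has a complex device type.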
bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
  auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
  return std::any_of(all_inputs_type.begin(), all_inputs_type.end(),
                     [](TypeId type_id) { return type_id == kNumberTypeComplex64; });
}

ExpanderPtr GraphKernelComplexExpander::GetExpander(const AnfNodePtr &) {
  return std::make_shared<ComplexOpExpander>();
}
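
// ComplexOpExpander reuses the Python expander, but prefixes the op name in
// the json with "C" so that the complex variant of the op expander is used.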
bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!PyExpander::ExpandJsonInfo(cnode, kernel_json)) return false;
  (*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
  return true;
}
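
// Pass entry points: GraphKernelExpander expands the ops collected by
// GetExpandOps(), while GraphKernelComplexExpander expands complex-typed ops.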
bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
  expand_ops_ = GetExpandOps();
  return DoExpand(func_graph);
}

bool GraphKernelComplexExpander::Run(const FuncGraphPtr &func_graph) { return DoExpand(func_graph); }
}  // namespace opt
}  // namespace mindspore