/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/graph_kernel/graph_kernel_expander.h"

#include <string>
#include <set>
#include <utility>
#include <vector>
#include <tuple>
#include <algorithm>

#include "utils/context/graph_kernel_flags.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/optimizer/graph_kernel/expanders/expander_factory.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/split_umonad.h"
#include "backend/optimizer/graph_kernel/substitute_dropout.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "mindspore/core/ir/graph_utils.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pybind_api/ir/primitive_py.h"
#include "runtime/device/kernel_info.h"
#include "vm/segment_runner.h"

namespace mindspore {
namespace opt {
namespace {
using context::OpLevel_0;
using context::OpLevel_1;
constexpr size_t kAssignInputIdx = 1;
constexpr size_t kLambOptimizerInputIdx = 12;
constexpr size_t kLambWeightInputIdx = 4;
constexpr size_t kRandomInputIdx = 1;

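// Collect the ops that can be expanded into composite graph kernels, filtered by the
// target device, the fusion op level, and the user's enable/disable expand-ops flags.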
std::vector<PrimitivePtr> GetExpandOps() {
  std::vector<std::tuple<std::string, unsigned int, PrimitivePtr>> expand_ops_with_level = {
    {kAllTarget, OpLevel_0, prim::kPrimAddN},
    {kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
    {kAllTarget, OpLevel_0, prim::kPrimErfc},
    {kAllTarget, OpLevel_1, prim::kPrimExpandDims},
    {kAllTarget, OpLevel_0, prim::kPrimGeLU},
    {kAllTarget, OpLevel_0, prim::kPrimGeLUGrad},
    {kAllTarget, OpLevel_0, prim::kPrimSquare},
    {kAllTarget, OpLevel_0, prim::kPrimTile},
    {kAscendDevice, OpLevel_0, prim::kLambApplyOptimizerAssign},
    {kAscendDevice, OpLevel_0, prim::kLambApplyWeightAssign},
    {kAscendDevice, OpLevel_0, prim::kPrimClipByNormNoDivSum},
    {kAscendDevice, OpLevel_0, prim::kPrimSqrtGrad},
    {kAscendDevice, OpLevel_1, prim::kSoftmaxGradExt},
    {kAscendDevice, OpLevel_0, prim::kFusedMulAdd},
    {kGPUDevice, OpLevel_1, prim::kPrimBatchMatMul},
    {kGPUDevice, OpLevel_0, prim::kPrimBiasAdd},
    {kGPUDevice, OpLevel_1, prim::kPrimBiasAddGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimDropout},
    {kGPUDevice, OpLevel_0, prim::kPrimDropoutGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimFusedAdam},
    {kGPUDevice, OpLevel_0, prim::kPrimFusedAdamWeightDecay},
    {kGPUDevice, OpLevel_1, prim::kPrimMaximumGrad},
    {kGPUDevice, OpLevel_1, prim::kPrimMinimumGrad},
    {kGPUDevice, OpLevel_1, prim::kPrimLayerNorm},
    {kGPUDevice, OpLevel_1, prim::kPrimLayerNormGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimLogSoftmax},
    {kGPUDevice, OpLevel_0, prim::kPrimLogSoftmaxGrad},
    {kGPUDevice, OpLevel_1, prim::kPrimMatMul},
    {kGPUDevice, OpLevel_1, prim::kPrimReduceMean},
    {kGPUDevice, OpLevel_0, prim::kPrimRelu},
    {kGPUDevice, OpLevel_0, prim::kPrimReluGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimSigmoid},
    {kGPUDevice, OpLevel_0, prim::kPrimSigmoidGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogits},
    {kGPUDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogitsGrad},
    {kGPUDevice, OpLevel_0, prim::kPrimSlice},
    {kGPUDevice, OpLevel_1, prim::kPrimSoftmax},
    {kGPUDevice, OpLevel_1, prim::kPrimSoftmaxCrossEntropyWithLogits},
    {kGPUDevice, OpLevel_0, prim::kPrimSquaredDifference},
    {kGPUDevice, OpLevel_0, prim::kPrimSqueeze},
    {kGPUDevice, OpLevel_0, prim::kPrimEqualCount},
    {kGPUDevice, OpLevel_0, prim::kPrimSquareSumAll},
    {kGPUDevice, OpLevel_0, prim::kPrimIdentityMath},
    {kGPUDevice, OpLevel_0, prim::kPrimOnesLike},
    {kGPUDevice, OpLevel_0, prim::kPrimStandardNormal},
  };
  const auto &flags = context::GraphKernelFlags::GetInstance();
  std::vector<PrimitivePtr> expand_ops = GetValidOps(expand_ops_with_level, flags.fusion_ops_level);
  OpListFilter(&expand_ops, flags.enable_expand_ops_only, flags.enable_expand_ops, flags.disable_expand_ops);
  return expand_ops;
}
}  // namespace

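// Dump the op info of the node into kernel_json; this json is the input of the
// python op expander.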
bool PyExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
  DumpOption dump_option;
  dump_option.extract_opinfo_from_anfnode = true;
  kernel::AkgKernelJsonGenerator json_generator(dump_option);
  return json_generator.CollectJson(node, kernel_json);
}

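// Expand the node with the python op expander: dump the node to json, call
// kGetGraphKernelOpExpander through the python adapter, then decode the returned
// json description into a FuncGraph.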
FuncGraphPtr PyExpander::CreateExpandFuncGraph(const CNodePtr &node) {
  nlohmann::json kernel_json;
  if (!ExpandJsonInfo(node, &kernel_json)) {
    MS_LOG(ERROR) << "Expand json info of node " << node->DebugString(2) << " failed, ori_json:\n"
                  << kernel_json.dump();
    return nullptr;
  }
  auto node_desc_str = kernel_json.dump();

  // Call the graph kernel op generator implemented in python.
  MS_LOG(DEBUG) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] with input json:\n" << node_desc_str;
  auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGetGraphKernelOpExpander, node_desc_str);
  // Parse the result.
  if (py::isinstance<py::none>(ret)) {
    MS_LOG(ERROR) << "CallPyFn: [" << kGetGraphKernelOpExpander << "] returned an invalid result, input json:\n"
                  << node_desc_str;
    return nullptr;
  }
  std::string kernel_desc_str = py::cast<std::string>(ret);
  if (kernel_desc_str.empty()) {
    return nullptr;
  }
  // Decode the json to a func_graph.
  return JsonDescToAnf(kernel_desc_str);
}

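// Prefer the C++ expander registered in OpExpanderFactory; fall back to the python
// expander when no C++ expander exists for this op.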
FuncGraphPtr DefaultExpander::CreateExpandFuncGraph(const CNodePtr &node) {
  auto expander_ptr = expanders::OpExpanderFactory::Instance().GetExpander(AnfAlgo::GetCNodeName(node));
  if (expander_ptr == nullptr) {
    return PyExpander::CreateExpandFuncGraph(node);
  }
  // Collect the shape, type and format of all inputs and outputs from the kernel build info.
  expanders::BaseInfoList inputs(node->size() - 1);
  expanders::BaseInfoList outputs(AnfAlgo::GetOutputTensorNum(node));
  for (size_t i = 0; i < inputs.size(); i++) {
    auto shape = AnfAlgo::GetInputDeviceShape(node, i);
    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(inputs[i].shape), SizeToLong);
    inputs[i].type = AnfAlgo::GetInputDeviceDataType(node, i);
    inputs[i].format = AnfAlgo::GetInputFormat(node, i);
  }
  for (size_t i = 0; i < outputs.size(); i++) {
    auto shape = AnfAlgo::GetOutputDeviceShape(node, i);
    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(outputs[i].shape), SizeToLong);
    outputs[i].type = AnfAlgo::GetOutputDeviceDataType(node, i);
    outputs[i].format = AnfAlgo::GetOutputFormat(node, i);
  }
  auto &attrs = AnfAlgo::GetCNodePrimitive(node)->attrs();
  auto litegraph = expander_ptr->Run(inputs, outputs, attrs, kernel::GetStrProcessorFromContext());
  if (litegraph == nullptr) {
    MS_LOG(INFO) << "Failed to expand node " << node->fullname_with_scope();
    return nullptr;
  }
  return LiteGraph2AnfGraph(litegraph);
}

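// Fuse the expanded func_graph into the main graph as a single GraphKernel cnode,
// reusing the inputs of the original node.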
AnfNodePtr PyExpander::CreateExpandGraphKernel(const FuncGraphPtr &new_func_graph, const CNodePtr &old_node) {
  auto func_graph = old_node->func_graph();
  std::vector<AnfNodePtr> inputs(old_node->inputs().begin() + 1, old_node->inputs().end());
  AnfNodePtrList kernel_nodes;
  AnfNodePtrList outputs;
  EliminateRedundantParameters(new_func_graph, &inputs);
  kernel::GetValidKernelNodes(new_func_graph, &kernel_nodes);
  kernel::GetFuncGraphOutputNodes(new_func_graph, &outputs);
  auto graph_kernel_node = CreateNewFuseCNode(func_graph, new_func_graph, inputs, outputs);
  SetNewKernelInfo(graph_kernel_node, new_func_graph, inputs, outputs);
  MS_LOG(DEBUG) << "Expand node: " << old_node->fullname_with_scope()
                << " with: " << graph_kernel_node->fullname_with_scope();
  return graph_kernel_node;
}

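// Expand one node. Returns the new GraphKernel node, or nullptr if the expansion
// fails or the output number of the composite node differs from the original one.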
AnfNodePtr PyExpander::Run(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto new_func_graph = CreateExpandFuncGraph(cnode);
  if (new_func_graph == nullptr) {
    return nullptr;
  }
  new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfAlgo::GetCNodeName(cnode)));
  auto graph_kernel_node = CreateExpandGraphKernel(new_func_graph, cnode);
  if (AnfAlgo::GetOutputTensorNum(node) != AnfAlgo::GetOutputTensorNum(graph_kernel_node)) {
    MS_LOG(ERROR) << "The output num of composite node (" << AnfAlgo::GetOutputTensorNum(graph_kernel_node)
                  << ") does not match the original basic node (" << AnfAlgo::GetOutputTensorNum(node)
                  << "): " << node->fullname_with_scope();
    return nullptr;
  }
  return graph_kernel_node;
}

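// Dropout and the ops carrying a UMonad input need dedicated pre-processing
// expanders; all other ops use the DefaultExpander.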
ExpanderPtr GraphKernelExpander::GetExpander(const AnfNodePtr &node) {
  std::vector<std::pair<PrimitivePtr, ExpanderPtr>> expanders = {
    {prim::kPrimDropout, std::make_shared<DropoutExpander>()},
    {prim::kPrimAssignAdd, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
    {prim::kPrimAssignSub, std::make_shared<OpUMonadExpander>(kAssignInputIdx)},
    {prim::kLambApplyOptimizerAssign, std::make_shared<OpUMonadExpander>(kLambOptimizerInputIdx)},
    {prim::kLambApplyWeightAssign, std::make_shared<OpUMonadExpander>(kLambWeightInputIdx)},
    {prim::kPrimStandardNormal, std::make_shared<OpUMonadExpander>(kRandomInputIdx)},
  };

  for (auto &e : expanders) {
    if (IsPrimitiveCNode(node, e.first)) {
      return e.second;
    }
  }
  return std::make_shared<DefaultExpander>();
}

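// Traverse all nodes in reversed topological order and replace each expandable node
// with its expanded GraphKernel node.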
bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  std::reverse(todos.begin(), todos.end());
  auto mng = func_graph->manager();
  MS_EXCEPTION_IF_NULL(mng);
  for (const auto &n : todos) {
    auto node = n->cast<CNodePtr>();
    if (node == nullptr || AnfAlgo::IsGraphKernel(node) || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) ||
        !CanExpand(node)) {
      continue;
    }

    MS_LOG(DEBUG) << "Expanding node: " << node->fullname_with_scope();
    auto new_node = GetExpander(node)->Run(node);
    if (new_node == nullptr) {
      MS_LOG(DEBUG) << "Skipped node: " << node->fullname_with_scope();
      continue;
    }
    (void)mng->Replace(node, new_node);
    changed = true;
  }
  return changed;
}

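// Only expand a node when at least one of its inputs is of complex64 type.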
bool GraphKernelComplexExpander::CanExpand(const CNodePtr &node) const {
  auto all_inputs_type = AnfAlgo::GetAllInputDeviceTypes(node);
  return std::any_of(all_inputs_type.begin(), all_inputs_type.end(),
                     [](TypeId type) { return type == kNumberTypeComplex64; });
}

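// Every complex op shares the same expander, which reuses the python expander on a
// renamed ("C"-prefixed) op.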
ExpanderPtr GraphKernelComplexExpander::GetExpander(const AnfNodePtr &) {
  return std::make_shared<ComplexOpExpander>();
}

bool ComplexOpExpander::ExpandJsonInfo(const AnfNodePtr &node, nlohmann::json *kernel_json) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (!PyExpander::ExpandJsonInfo(cnode, kernel_json)) {
    return false;
  }
  // Prefix the op name with "C" to select the complex number version of the op expander.
  (*kernel_json)["name"] = std::string("C") + AnfAlgo::GetCNodeName(cnode);
  return true;
}

bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) {
  expand_ops_ = GetExpandOps();
  return DoExpand(func_graph);
}

bool GraphKernelComplexExpander::Run(const FuncGraphPtr &func_graph) { return DoExpand(func_graph); }
}  // namespace opt
}  // namespace mindspore