/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/graph_kernel/adapter/expander.h"

#include <map>
#include <set>
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include "backend/common/graph_kernel/convert_input_and_attr.h"
#include "mindspore/core/ops/structure_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/random_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "include/common/utils/python_adapter.h"
#include "kernel/graph_kernel/graph_kernel_json_generator.h"
#include "backend/common/graph_kernel/core/split_umonad.h"
#include "backend/common/graph_kernel/substitute_dropout.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/graph_kernel/adapter/callback_impl.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/pass/inplace_assign_for_custom_op.h"
#include "kernel/common_utils.h"
#include "utils/ms_context.h"
#include "include/common/debug/anf_ir_dump.h"
#include "ir/func_graph_cloner.h"
#include "mindspore/core/ops/op_name.h"

namespace mindspore::graphkernel {
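// Wrap `init` with the decorators required by the node's primitive, e.g. UMonad
// splitting for ops with side-effect inputs, depend-value handling for ops whose
// expansion needs constant inputs, and a dynamic-shape attribute setter.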
ExpanderPtr GetExpander(const AnfNodePtr &node, const ExpanderPtr &init) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(init);
  if (IsComplexOp(node)) {
    return ComplexOpDecorator::Creator(init);
  }

  constexpr size_t kAssignInputIdx = 1;
  constexpr size_t kLambOptimizerInputIdx = 12;
  constexpr size_t kLambWeightInputIdx = 4;
  constexpr size_t kRandomInputIdx = 1;
  constexpr size_t kAdamInputIdx = 10;
  constexpr size_t kAdamWeightDecayInputIdx = 9;
  constexpr size_t kApplyMomentumInputIdx = 1;
  std::map<std::string, ExpanderCreatorFuncList> creators = {
    {prim::kPrimAssignAdd->name(), {OpUMonadExpanderDeco::GetCreator(kAssignInputIdx)}},
    {prim::kPrimAdamApplyOneWithDecayAssign->name(), {OpUMonadExpanderDeco::GetCreator(kIndex2)}},
    {prim::kLambApplyOptimizerAssign->name(), {OpUMonadExpanderDeco::GetCreator(kLambOptimizerInputIdx)}},
    {prim::kLambApplyWeightAssign->name(), {OpUMonadExpanderDeco::GetCreator(kLambWeightInputIdx)}},
    {prim::kPrimStandardNormal->name(), {OpUMonadExpanderDeco::GetCreator(kRandomInputIdx)}},
    {prim::kPrimAdam->name(), {OpUMonadExpanderDeco::GetCreator(kAdamInputIdx)}},
    {prim::kPrimAdamWeightDecay->name(), {OpUMonadExpanderDeco::GetCreator(kAdamWeightDecayInputIdx)}},
    {prim::kPrimApplyMomentum->name(), {OpUMonadExpanderDeco::GetCreator(kApplyMomentumInputIdx)}},
    {prim::kPrimDropout->name(), {DropoutExpanderDeco::Creator}},
    {prim::kPrimArgMaxWithValue->name(), {ArgWithValueDeco::Creator}},
    {prim::kPrimArgMinWithValue->name(), {ArgWithValueDeco::Creator}},
    {prim::kPrimSolveTriangular->name(), {ProcessCustomOpDeco::Creator}},
    {prim::kPrimLU->name(), {ProcessCustomOpDeco::Creator}},
    {prim::kPrimExpandDims->name(), {DependValueDeco::GetCreator({1})}},
    {prim::kPrimReduceMean->name(), {DependValueDeco::GetCreator({1})}},
    {prim::kPrimTile->name(), {DependValueDeco::GetCreator({1})}},
    {prim::kPrimSlice->name(), {DependValueDeco::GetCreator({1, 2})}},
    {prim::kPrimGather->name(), {DependValueDeco::GetCreator({2})}},
    {prim::kPrimAddN->name(), {UnfoldMakeTupleDeco::Creator}}};

  ExpanderPtr expander = init;
  const auto iter = creators.find(GetCNodePrimitive(node)->name());
  if (iter != creators.end()) {
    expander = WrapExpander(expander, iter->second);
  }
  if (common::AnfAlgo::IsDynamicShape(node)) {
    MS_LOG(INFO) << "Try to expand the dynamic shape node.";
    expander = SetDynamicShapeAttrDeco::Creator(expander);
  }
  return expander;
}

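// Build a default LitegraphExpander, whose callback optionally supports shape
// inference (when `abstract` is true), and decorate it for the given node.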
ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
  ExpanderPtr expander =
    abstract
      ? std::make_shared<LitegraphExpander>(
          std::static_pointer_cast<Callback>(std::make_shared<CallbackImplWithInferShape>()))
      : std::make_shared<LitegraphExpander>(std::static_pointer_cast<Callback>(std::make_shared<CallbackImpl>()));
  return GetExpander(node, expander);
}

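// Check whether the node is allowed to be expanded as a fallback when kernel
// selection fails. The candidate op list below is filtered by level: OpLevel_1
// ops take part only when MS_DEV_EXPANDER_FALLBACK is set to "1".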
bool CanExpandFallback(const AnfNodePtr &node) {
  if (!node->isa<CNode>()) {
    return false;
  }
  if (common::GetEnv("MS_DEV_EXPANDER_FALLBACK") == "off") {
    return false;
  }
  // Operators with the 'batch_rank' attribute, which only appears in the vmap scenario, are not supported currently.
  if (common::AnfAlgo::HasNodeAttr(ops::kBatchRank, node->cast<CNodePtr>())) {
    return false;
  }
  static const std::vector<OpWithLevel> expander_fallback_ops_with_level = {
    {kAllTarget, OpLevel_0, prim::kPrimEqualCount},
    {kAllTarget, OpLevel_0, prim::kPrimSoftsign},
    {kAllTarget, OpLevel_0, prim::kPrimSquare},
    {kAllTarget, OpLevel_0, prim::kPrimBiasAdd},
    {kAllTarget, OpLevel_0, prim::kPrimReLU},
    {kAllTarget, OpLevel_0, prim::kPrimRelu},
    {kAllTarget, OpLevel_0, prim::kPrimSigmoid},
    {kAllTarget, OpLevel_0, prim::kPrimSoftplus},
    {kAllTarget, OpLevel_0, prim::kPrimSoftplusGrad},
    {kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
    {kAllTarget, OpLevel_0, prim::kLambApplyOptimizerAssign},
    {kAllTarget, OpLevel_0, prim::kLambApplyWeightAssign},
    {kAllTarget, OpLevel_0, prim::kPrimAdamWeightDecay},
    {kAllTarget, OpLevel_0, prim::kPrimStandardNormal},
    {kAllTarget, OpLevel_0, prim::kPrimAdam},
    // some ops, including custom ops, use expander fallback only on Ascend
    {kAscendDevice, OpLevel_0, prim::kPrimSolveTriangular},
    {kAscendDevice, OpLevel_0, prim::kPrimLU},
    // disabled
    {kAllTarget, OpLevel_1, prim::kPrimAddN},
    {kAllTarget, OpLevel_1, prim::kPrimErfc},
    {kAllTarget, OpLevel_1, prim::kPrimExpandDims},
    {kAllTarget, OpLevel_1, prim::kPrimGeLU},
    {kAllTarget, OpLevel_1, prim::kPrimGeLUGrad},
    {kAllTarget, OpLevel_1, prim::kPrimSqrtGrad},
    {kAllTarget, OpLevel_1, prim::kPrimTile},
    {kAllTarget, OpLevel_1, prim::kPrimClipByNormNoDivSum},
    {kAllTarget, OpLevel_1, prim::kSoftmaxGradExt},
    {kAllTarget, OpLevel_1, prim::kFusedMulAdd},
    {kAllTarget, OpLevel_1, prim::kPrimBatchMatMul},
    {kAllTarget, OpLevel_1, prim::kPrimBiasAddGrad},
    {kAllTarget, OpLevel_1, prim::kPrimDropout},
    {kAllTarget, OpLevel_1, prim::kPrimDropoutGrad},
    {kAllTarget, OpLevel_1, prim::kPrimMaximumGrad},
    {kAllTarget, OpLevel_1, prim::kPrimMinimumGrad},
    {kAllTarget, OpLevel_1, prim::kPrimLayerNorm},
    {kAllTarget, OpLevel_1, prim::kPrimLayerNormGrad},
    {kAllTarget, OpLevel_1, prim::kPrimLogSoftmax},
    {kAllTarget, OpLevel_1, prim::kPrimLogSoftmaxV2},
    {kAllTarget, OpLevel_1, prim::kPrimLogSoftmaxGrad},
    {kAllTarget, OpLevel_1, prim::kPrimMatMul},
    {kAllTarget, OpLevel_1, prim::kPrimReduceMean},
    {kAllTarget, OpLevel_1, prim::kPrimReluGrad},
    {kAllTarget, OpLevel_1, prim::kPrimSigmoidGrad},
    {kAllTarget, OpLevel_1, prim::kPrimSigmoidCrossEntropyWithLogits},
    {kAllTarget, OpLevel_1, prim::kPrimSigmoidCrossEntropyWithLogitsGrad},
    {kAllTarget, OpLevel_1, prim::kPrimSlice},
    {kAllTarget, OpLevel_1, prim::kPrimSoftmax},
    {kAllTarget, OpLevel_1, prim::kPrimSoftmaxV2},
    {kAllTarget, OpLevel_1, prim::kPrimSoftmaxCrossEntropyWithLogits},
    {kAllTarget, OpLevel_1, prim::kPrimSquaredDifference},
    {kAllTarget, OpLevel_1, prim::kPrimSqueeze},
    {kAllTarget, OpLevel_1, prim::kPrimSquareSumAll},
    {kAllTarget, OpLevel_1, prim::kPrimIdentityMath},
    {kAllTarget, OpLevel_1, prim::kPrimOnesLike},
    {kAllTarget, OpLevel_1, prim::kPrimTanhGrad},
  };
  unsigned int op_level = (common::GetEnv("MS_DEV_EXPANDER_FALLBACK") == "1") ? 1 : 0;
  auto ops = GkUtils::GetValidOps(expander_fallback_ops_with_level, op_level, {}, {}, {});
  return std::any_of(ops.begin(), ops.end(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

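// Expand the node, then run the InplaceAssignForCustomOp pass on the expanded
// subgraph so that the custom ops inside it get in-place assign processing.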
AnfNodePtr ProcessCustomOpDeco::Run(const AnfNodePtr &node) {
  if (node == nullptr) {
    return nullptr;
  }
  auto new_node = decorated_->Run(node);
  auto graph = GetCNodeFuncGraph(new_node);
  if (graph == nullptr) {
    return nullptr;
  }
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::InplaceAssignForCustomOp>());
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(graph);
  return new_node;
}

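// Expand the node into a subgraph and run `func` (typically kernel selection) on
// every real kernel inside it. Return the expanded node on success, or nullptr if
// the node cannot be expanded or any inner kernel fails in `func`.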
AnfNodePtr TryExpandCNode(const AnfNodePtr &node, const std::function<bool(const CNodePtr &)> &func) {
  if (!CanExpandFallback(node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  auto expander = GetExpander(node);
  auto res = expander->Run(node);
  auto expand_fg = GetCNodeFuncGraph(res);
  if (expand_fg == nullptr) {
    return nullptr;
  }
  // For Ascend, the kernel selection function may check and change the kernel's input Parameters,
  // so we replace the inner Parameters with the outer ones.
  bool need_replace_parameter = Callback::Instance()->GetTargetFromContext() == kAscendDevice;
  std::map<AnfNodePtr, AnfNodePtr> param_map;
  if (need_replace_parameter) {
    auto &params = expand_fg->parameters();
    for (size_t i = 0; i < params.size(); i++) {
      if (cnode->input(i + 1)->isa<Parameter>()) {
        param_map[params[i]] = cnode->input(i + 1);
      }
    }
    need_replace_parameter = !param_map.empty();
  }

  auto todos = TopoSort(expand_fg->output());
  for (const auto &inner_node : todos) {
    if (!AnfUtils::IsRealCNodeKernel(inner_node)) {
      continue;
    }
    try {
      MS_LOG_TRY_CATCH_SCOPE;
      bool suc = false;
      if (OpDefAdapter::NeedConvertGK2FE(inner_node)) {
        (void)ConvertGraphKernelToFrontEnd::Process(inner_node);
      }
      auto inner_cnode = inner_node->cast<CNodePtr>();
      if (need_replace_parameter) {
        std::vector<std::pair<size_t, AnfNodePtr>> ori_input;
        for (size_t i = 1; i < inner_cnode->size(); i++) {
          auto iter = param_map.find(inner_cnode->input(i));
          if (iter != param_map.end()) {
            MS_LOG(DEBUG) << "Replace " << inner_cnode->input(i)->DebugString() << " by "
                          << iter->second->DebugString();
            (void)ori_input.emplace_back(i, inner_cnode->input(i));
            inner_cnode->set_input(i, iter->second);
          }
        }
        suc = func(inner_cnode);
        // restore the original inputs
        for (auto &ori : ori_input) {
          inner_cnode->set_input(ori.first, ori.second);
        }
      } else {
        suc = func(inner_cnode);
      }
      if (!suc) {
        MS_LOG(INFO) << "ExpanderFallback: select kernel [" << inner_node->fullname_with_scope() << "] failed.";
        res = nullptr;
        break;
      }
    } catch (std::exception &e) {
      MS_LOG(WARNING) << "ExpanderFallback: error in select kernel for [" << inner_node->fullname_with_scope()
                      << "], msg: " << e.what();
      res = nullptr;
      break;
    }
  }
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kAdvanced)) {
    DumpIR("verbose_ir_files/expand_" + GetCNodeFuncName(node->cast<CNodePtr>()) + ".ir", expand_fg);
  }
#endif
  return res;
}

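// Mark the cnode with the input/output dynamic-shape attributes when its inputs
// or outputs actually have dynamic shapes.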
void SetDynamicShapeAttrToCNode(const CNodePtr &cnode) {
  auto in_dynamic = common::AnfAlgo::IsNodeInputDynamicShape(cnode);
  auto out_dynamic = common::AnfAlgo::IsNodeOutputDynamicShape(cnode);
  if (in_dynamic && !common::AnfAlgo::HasNodeAttr(kAttrInputIsDynamicShape, cnode)) {
    common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cnode);
  }
  if (out_dynamic && !common::AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicShape, cnode)) {
    common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
  }
}

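// Set the dynamic-shape attributes for every real kernel in the graph.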
void SetDynamicShapeAttr(const FuncGraphPtr &graph) {
  auto todos = TopoSort(graph->get_return());
  for (const auto &node : todos) {
    if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
      continue;
    }
    auto cnode = dyn_cast<CNode>(node);
    SetDynamicShapeAttrToCNode(cnode);
  }
}

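// Expand the node, then propagate the dynamic-shape attributes into the expanded
// subgraph before rebinding it to the new cnode.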
AnfNodePtr SetDynamicShapeAttrDeco::Run(const AnfNodePtr &node) {
  auto new_node = decorated_->Run(node);
  if (new_node == nullptr) {
    return nullptr;
  }
  auto new_cnode = dyn_cast<CNode>(new_node);
  auto expand_fg = GetCNodeFuncGraph(new_cnode);
  SetDynamicShapeAttr(expand_fg);
  new_cnode->set_input(0, NewValueNode(expand_fg));
  return new_cnode;
}

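// Clone the cnode and prefix its primitive name with "C" (keeping the original
// attrs) before expanding, so that the complex variant of the op is expanded.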
AnfNodePtr ComplexOpDecorator::Run(const AnfNodePtr &node) {
  auto cnode = QuickCloneCNode(node);
  auto prim = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
  cnode->set_input(0, NewValueNode(std::make_shared<Primitive>("C" + prim->name(), prim->attrs())));
  return decorated_->Run(cnode);
}

// Used for ArgMaxWithValue (ArgMinWithValue), whose output is the tuple (index, value).
// Currently it is expanded only when output[1] has users and output[0] has none;
// in that case, ArgMaxWithValue (ArgMinWithValue) can be converted to ReduceMax (ReduceMin).
// If output[0] has users, expanding is not allowed.
AnfNodePtr ArgWithValueDeco::Run(const AnfNodePtr &node) {
  auto mng = GkUtils::GetFuncGraphManager(node->func_graph());
  bool res = false;
  if (auto iter = mng->node_users().find(node); iter != mng->node_users().end()) {
    auto output_info_list = iter->second;
    res = std::all_of(output_info_list.begin(), output_info_list.end(), [](const std::pair<AnfNodePtr, int> &info) {
      if (IsPrimitiveCNode(info.first, prim::kPrimTupleGetItem)) {
        const auto &cnode = info.first->cast<CNodePtr>();
        auto value_ptr = GetValueNode(cnode->input(kInputNodeOutputIndexInTupleGetItem));
        MS_EXCEPTION_IF_NULL(value_ptr);
        return GetValue<int64_t>(value_ptr) == 1;
      }
      return false;
    });
  }
  return res ? decorated_->Run(node) : nullptr;
}

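// If the node's single input is a MakeTuple (e.g. for AddN), unfold the tuple
// elements into direct inputs of a cloned cnode before expanding.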
AnfNodePtr UnfoldMakeTupleDeco::Run(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() == kIndex2 && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) {
    auto make_tuple_cnode = cnode->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple_cnode);
    std::vector<AnfNodePtr> new_inputs;
    new_inputs.push_back(cnode->input(0));
    for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
      new_inputs.push_back(make_tuple_cnode->input(i));
    }
    cnode = QuickCloneCNode(cnode);
    cnode->set_inputs(new_inputs);
  }
  return decorated_->Run(cnode);
}

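// Inline the expanded subgraph into the main graph, replacing the expanding node
// with the output of the inlined graph.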
void InlineExpandFuncGraph(const AnfNodePtr &expanding_node, const FuncGraphPtr &expanded_graph) {
  auto main_graph = expanding_node->func_graph();
  auto mng = main_graph->manager();
  if (mng == nullptr) {
    mng = Manage(main_graph, true);
    main_graph->set_manager(mng);
  }
  auto cnode = expanding_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  AnfNodePtrList inp(cnode->inputs().begin() + 1, cnode->inputs().end());
  auto out = InlineClone(expanded_graph, main_graph, inp, cnode);
  (void)mng->Replace(expanding_node, out);
}

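// Return true if any tensor input of the node has a complex element type
// (Complex64 or Complex128).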
bool IsComplexOp(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input = cnode->input(i);
    TypePtr input_type = input->Type();
    if (input_type == nullptr || !input_type->isa<TensorType>()) {
      return false;
    }
    input_type = input_type->cast<TensorTypePtr>()->element();
    if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) {
      return true;
    }
  }
  return false;
}
}  // namespace mindspore::graphkernel