/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/graph_kernel/adapter/expander.h"

#include <map>
#include <set>
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include "backend/common/graph_kernel/convert_input_and_attr.h"
#include "mindspore/core/ops/structure_ops.h"
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/random_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "include/common/utils/python_adapter.h"
#include "kernel/graph_kernel/graph_kernel_json_generator.h"
#include "backend/common/graph_kernel/core/split_umonad.h"
#include "backend/common/graph_kernel/substitute_dropout.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/graph_kernel/adapter/callback_impl.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/pass/inplace_assign_for_custom_op.h"
#include "kernel/common_utils.h"
#include "utils/ms_context.h"
#include "include/common/debug/anf_ir_dump.h"
#include "ir/func_graph_cloner.h"
#include "mindspore/core/ops/op_name.h"

namespace mindspore::graphkernel {
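// Select op-specific decorators for the node and wrap them around the given base expander:
// complex-typed ops get ComplexOpDecorator, ops in the map below get UMonad/depend-value/
// tuple-unfolding style decorators by primitive name, and dynamic shape nodes additionally
// get SetDynamicShapeAttrDeco.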
ExpanderPtr GetExpander(const AnfNodePtr &node, const ExpanderPtr &init) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(init);
  if (IsComplexOp(node)) {
    return ComplexOpDecorator::Creator(init);
  }

  constexpr size_t kAssignInputIdx = 1;
  constexpr size_t kLambOptimizerInputIdx = 12;
  constexpr size_t kLambWeightInputIdx = 4;
  constexpr size_t kRandomInputIdx = 1;
  constexpr size_t kAdamInputIdx = 10;
  constexpr size_t kAdamWeightDecayInputIdx = 9;
  constexpr size_t kApplyMomentumInputIdx = 1;
  std::map<std::string, ExpanderCreatorFuncList> creators = {
    {prim::kPrimAssignAdd->name(), {OpUMonadExpanderDeco::GetCreator(kAssignInputIdx)}},
    {prim::kPrimAdamApplyOneWithDecayAssign->name(), {OpUMonadExpanderDeco::GetCreator(kIndex2)}},
    {prim::kLambApplyOptimizerAssign->name(), {OpUMonadExpanderDeco::GetCreator(kLambOptimizerInputIdx)}},
    {prim::kLambApplyWeightAssign->name(), {OpUMonadExpanderDeco::GetCreator(kLambWeightInputIdx)}},
    {prim::kPrimStandardNormal->name(), {OpUMonadExpanderDeco::GetCreator(kRandomInputIdx)}},
    {prim::kPrimAdam->name(), {OpUMonadExpanderDeco::GetCreator(kAdamInputIdx)}},
    {prim::kPrimAdamWeightDecay->name(), {OpUMonadExpanderDeco::GetCreator(kAdamWeightDecayInputIdx)}},
    {prim::kPrimApplyMomentum->name(), {OpUMonadExpanderDeco::GetCreator(kApplyMomentumInputIdx)}},
    {prim::kPrimDropout->name(), {DropoutExpanderDeco::Creator}},
    {prim::kPrimArgMaxWithValue->name(), {ArgWithValueDeco::Creator}},
    {prim::kPrimArgMinWithValue->name(), {ArgWithValueDeco::Creator}},
    {prim::kPrimSolveTriangular->name(), {ProcessCustomOpDeco::Creator}},
    {prim::kPrimLU->name(), {ProcessCustomOpDeco::Creator}},
    {prim::kPrimExpandDims->name(), {DependValueDeco::GetCreator({1})}},
    {prim::kPrimReduceMean->name(), {DependValueDeco::GetCreator({1})}},
    {prim::kPrimTile->name(), {DependValueDeco::GetCreator({1})}},
    {prim::kPrimSlice->name(), {DependValueDeco::GetCreator({1, 2})}},
    {prim::kPrimGather->name(), {DependValueDeco::GetCreator({2})}},
    {prim::kPrimAddN->name(), {UnfoldMakeTupleDeco::Creator}}};

  ExpanderPtr expander = init;
  const auto iter = creators.find(GetCNodePrimitive(node)->name());
  if (iter != creators.end()) {
    expander = WrapExpander(expander, iter->second);
  }
  if (common::AnfAlgo::IsDynamicShape(node)) {
    MS_LOG(INFO) << "Try to expand dynamic shape node.";
    expander = SetDynamicShapeAttrDeco::Creator(expander);
  }
  return expander;
}

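// Convenience overload: build a LitegraphExpander whose callback also infers shapes when
// `abstract` is true, then apply the decorator selection above.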
ExpanderPtr GetExpander(const AnfNodePtr &node, bool abstract) {
  ExpanderPtr expander =
    abstract
      ? std::make_shared<LitegraphExpander>(
          std::static_pointer_cast<Callback>(std::make_shared<CallbackImplWithInferShape>()))
      : std::make_shared<LitegraphExpander>(std::static_pointer_cast<Callback>(std::make_shared<CallbackImpl>()));
  return GetExpander(node, expander);
}

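// Whether this node may take the expander fallback path. The path can be turned off with the
// MS_DEV_EXPANDER_FALLBACK environment variable and is limited to the op list below.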
bool CanExpandFallback(const AnfNodePtr &node) {
  if (!node->isa<CNode>()) {
    return false;
  }
  if (common::GetEnv("MS_DEV_EXPANDER_FALLBACK") == "off") {
    return false;
  }
  // Operators with 'batch_rank' attribute, which only appears in the vmap scenario, are not supported currently.
  if (common::AnfAlgo::HasNodeAttr(ops::kBatchRank, node->cast<CNodePtr>())) {
    return false;
  }
  static const std::vector<OpWithLevel> expander_fallback_ops_with_level = {
    {kAllTarget, OpLevel_0, prim::kPrimEqualCount},
    {kAllTarget, OpLevel_0, prim::kPrimSoftsign},
    {kAllTarget, OpLevel_0, prim::kPrimSquare},
    {kAllTarget, OpLevel_0, prim::kPrimBiasAdd},
    {kAllTarget, OpLevel_0, prim::kPrimReLU},
    {kAllTarget, OpLevel_0, prim::kPrimRelu},
    {kAllTarget, OpLevel_0, prim::kPrimSigmoid},
    {kAllTarget, OpLevel_0, prim::kPrimBiasAdd},
    {kAllTarget, OpLevel_0, prim::kPrimReLU},
    {kAllTarget, OpLevel_0, prim::kPrimSoftplus},
    {kAllTarget, OpLevel_0, prim::kPrimSoftplusGrad},
    {kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
    {kAllTarget, OpLevel_0, prim::kLambApplyOptimizerAssign},
    {kAllTarget, OpLevel_0, prim::kLambApplyWeightAssign},
    {kAllTarget, OpLevel_0, prim::kPrimAdamWeightDecay},
    {kAllTarget, OpLevel_0, prim::kPrimStandardNormal},
    {kAllTarget, OpLevel_0, prim::kPrimAdam},
    // Some ops, including custom ops, use expander fallback only on Ascend.
    {kAscendDevice, OpLevel_0, prim::kPrimSolveTriangular},
    {kAscendDevice, OpLevel_0, prim::kPrimLU},
    // OpLevel_1 ops below are disabled by default; they are enabled only when MS_DEV_EXPANDER_FALLBACK is "1".
    {kAllTarget, OpLevel_1, prim::kPrimAddN},
    {kAllTarget, OpLevel_1, prim::kPrimErfc},
    {kAllTarget, OpLevel_1, prim::kPrimExpandDims},
    {kAllTarget, OpLevel_1, prim::kPrimGeLU},
    {kAllTarget, OpLevel_1, prim::kPrimGeLUGrad},
    {kAllTarget, OpLevel_1, prim::kPrimSqrtGrad},
    {kAllTarget, OpLevel_1, prim::kPrimTile},
    {kAllTarget, OpLevel_1, prim::kPrimClipByNormNoDivSum},
    {kAllTarget, OpLevel_1, prim::kSoftmaxGradExt},
    {kAllTarget, OpLevel_1, prim::kFusedMulAdd},
    {kAllTarget, OpLevel_1, prim::kPrimBatchMatMul},
    {kAllTarget, OpLevel_1, prim::kPrimBiasAddGrad},
    {kAllTarget, OpLevel_1, prim::kPrimDropout},
    {kAllTarget, OpLevel_1, prim::kPrimDropoutGrad},
    {kAllTarget, OpLevel_1, prim::kPrimMaximumGrad},
    {kAllTarget, OpLevel_1, prim::kPrimMinimumGrad},
    {kAllTarget, OpLevel_1, prim::kPrimLayerNorm},
    {kAllTarget, OpLevel_1, prim::kPrimLayerNormGrad},
    {kAllTarget, OpLevel_1, prim::kPrimLogSoftmax},
    {kAllTarget, OpLevel_1, prim::kPrimLogSoftmaxV2},
    {kAllTarget, OpLevel_1, prim::kPrimLogSoftmaxGrad},
    {kAllTarget, OpLevel_1, prim::kPrimMatMul},
    {kAllTarget, OpLevel_1, prim::kPrimReduceMean},
    {kAllTarget, OpLevel_1, prim::kPrimReluGrad},
    {kAllTarget, OpLevel_1, prim::kPrimSigmoidGrad},
    {kAllTarget, OpLevel_1, prim::kPrimSigmoidCrossEntropyWithLogits},
    {kAllTarget, OpLevel_1, prim::kPrimSigmoidCrossEntropyWithLogitsGrad},
    {kAllTarget, OpLevel_1, prim::kPrimSlice},
    {kAllTarget, OpLevel_1, prim::kPrimSoftmax},
    {kAllTarget, OpLevel_1, prim::kPrimSoftmaxV2},
    {kAllTarget, OpLevel_1, prim::kPrimSoftmaxCrossEntropyWithLogits},
    {kAllTarget, OpLevel_1, prim::kPrimSquaredDifference},
    {kAllTarget, OpLevel_1, prim::kPrimSqueeze},
    {kAllTarget, OpLevel_1, prim::kPrimSquareSumAll},
    {kAllTarget, OpLevel_1, prim::kPrimIdentityMath},
    {kAllTarget, OpLevel_1, prim::kPrimOnesLike},
    {kAllTarget, OpLevel_1, prim::kPrimBiasAddGrad},
    {kAllTarget, OpLevel_1, prim::kPrimMaximumGrad},
    {kAllTarget, OpLevel_1, prim::kPrimMinimumGrad},
    {kAllTarget, OpLevel_1, prim::kPrimTanhGrad},
  };
  unsigned int op_level = (common::GetEnv("MS_DEV_EXPANDER_FALLBACK") == "1") ? 1 : 0;
  auto ops = GkUtils::GetValidOps(expander_fallback_ops_with_level, op_level, {}, {}, {});
  return std::any_of(ops.begin(), ops.end(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

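// After the decorated expansion, run the InplaceAssignForCustomOp pass on the expanded graph
// (registered above for SolveTriangular and LU).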
AnfNodePtr ProcessCustomOpDeco::Run(const AnfNodePtr &node) {
  if (node == nullptr) {
    return nullptr;
  }
  auto new_node = decorated_->Run(node);
  auto graph = GetCNodeFuncGraph(new_node);
  if (graph == nullptr) {
    return nullptr;
  }
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::InplaceAssignForCustomOp>());
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(graph);
  return new_node;
}

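// Expand the node into a subgraph and verify that every real kernel in it passes `func`
// (typically a kernel-selection check); returns the expanded node, or nullptr if expansion
// or any kernel check fails.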
AnfNodePtr TryExpandCNode(const AnfNodePtr &node, const std::function<bool(const CNodePtr &)> &func) {
  if (!CanExpandFallback(node)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  auto expander = GetExpander(node);
  auto res = expander->Run(node);
  auto expand_fg = GetCNodeFuncGraph(res);
  if (expand_fg == nullptr) {
    return nullptr;
  }
  // For Ascend, the kernel selection function may check and change the kernel's input Parameter,
  // so we replace the inner Parameters with the outer Parameters.
  bool need_replace_parameter = Callback::Instance()->GetTargetFromContext() == kAscendDevice;
  std::map<AnfNodePtr, AnfNodePtr> param_map;
  if (need_replace_parameter) {
    auto &params = expand_fg->parameters();
    for (size_t i = 0; i < params.size(); i++) {
      if (cnode->input(i + 1)->isa<Parameter>()) {
        param_map[params[i]] = cnode->input(i + 1);
      }
    }
    need_replace_parameter = !param_map.empty();
  }

  auto todos = TopoSort(expand_fg->output());
  for (const auto &inner_node : todos) {
    if (!AnfUtils::IsRealCNodeKernel(inner_node)) {
      continue;
    }
    try {
      MS_LOG_TRY_CATCH_SCOPE;
      bool suc = false;
      if (OpDefAdapter::NeedConvertGK2FE(inner_node)) {
        (void)ConvertGraphKernelToFrontEnd::Process(inner_node);
      }
      auto inner_cnode = inner_node->cast<CNodePtr>();
      if (need_replace_parameter) {
        std::vector<std::pair<size_t, AnfNodePtr>> ori_input;
        for (size_t i = 1; i < inner_cnode->size(); i++) {
          auto iter = param_map.find(inner_cnode->input(i));
          if (iter != param_map.end()) {
            MS_LOG(DEBUG) << "Replace " << inner_cnode->input(i)->DebugString() << " by "
                          << iter->second->DebugString();
            (void)ori_input.emplace_back(i, inner_cnode->input(i));
            inner_cnode->set_input(i, iter->second);
          }
        }
        suc = func(inner_cnode);
        // restore the original inputs
        for (auto &ori : ori_input) {
          inner_cnode->set_input(ori.first, ori.second);
        }
      } else {
        suc = func(inner_cnode);
      }
      if (!suc) {
        MS_LOG(INFO) << "ExpanderFallback: select kernel [" << inner_node->fullname_with_scope() << "] failed.";
        res = nullptr;
        break;
      }
    } catch (std::exception &e) {
      MS_LOG(WARNING) << "ExpanderFallback: error in select kernel for [" << inner_node->fullname_with_scope()
                      << "], msg: " << e.what();
      res = nullptr;
      break;
    }
  }
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kAdvanced)) {
    DumpIR("verbose_ir_files/expand_" + GetCNodeFuncName(node->cast<CNodePtr>()) + ".ir", expand_fg);
  }
#endif
  return res;
}

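// Mark the cnode with input/output dynamic-shape attributes when the corresponding shapes are dynamic.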
void SetDynamicShapeAttrToCNode(const CNodePtr &cnode) {
  auto in_dynamic = common::AnfAlgo::IsNodeInputDynamicShape(cnode);
  auto out_dynamic = common::AnfAlgo::IsNodeOutputDynamicShape(cnode);
  if (in_dynamic && !common::AnfAlgo::HasNodeAttr(kAttrInputIsDynamicShape, cnode)) {
    common::AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cnode);
  }
  if (out_dynamic && !common::AnfAlgo::HasNodeAttr(kAttrOutputIsDynamicShape, cnode)) {
    common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
  }
}

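// Set the dynamic-shape attributes for every real kernel cnode in the graph.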
void SetDynamicShapeAttr(const FuncGraphPtr &graph) {
  auto todos = TopoSort(graph->get_return());
  for (const auto &node : todos) {
    if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
      continue;
    }
    auto cnode = dyn_cast<CNode>(node);
    SetDynamicShapeAttrToCNode(cnode);
  }
}

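// Decorator: after expansion, propagate dynamic-shape attributes into the expanded subgraph.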
AnfNodePtr SetDynamicShapeAttrDeco::Run(const AnfNodePtr &node) {
  auto new_node = decorated_->Run(node);
  if (new_node == nullptr) {
    return nullptr;
  }
  auto new_cnode = dyn_cast<CNode>(new_node);
  auto expand_fg = GetCNodeFuncGraph(new_cnode);
  SetDynamicShapeAttr(expand_fg);
  new_cnode->set_input(0, NewValueNode(expand_fg));
  return new_cnode;
}

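// Decorator for complex ops: clone the cnode and replace its primitive with one whose name is
// prefixed by "C", keeping the original attributes, before running the decorated expander.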
AnfNodePtr ComplexOpDecorator::Run(const AnfNodePtr &node) {
  auto cnode = QuickCloneCNode(node);
  auto prim = GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
  cnode->set_input(0, NewValueNode(std::make_shared<Primitive>("C" + prim->name(), prim->attrs())));
  return decorated_->Run(cnode);
}

// Used for ArgMaxWithValue (ArgMinWithValue), whose output is the tuple (index, value).
// Currently we only expand it when output[1] has users and output[0] has no users.
// In this case, ArgMaxWithValue (ArgMinWithValue) can be converted to ReduceMax (ReduceMin).
// If output[0] has users, expanding is not allowed.
AnfNodePtr ArgWithValueDeco::Run(const AnfNodePtr &node) {
  auto mng = GkUtils::GetFuncGraphManager(node->func_graph());
  bool res = false;
  if (auto iter = mng->node_users().find(node); iter != mng->node_users().end()) {
    auto output_info_list = iter->second;
    res = std::all_of(output_info_list.begin(), output_info_list.end(), [](const std::pair<AnfNodePtr, int> &info) {
      if (IsPrimitiveCNode(info.first, prim::kPrimTupleGetItem)) {
        const auto &cnode = info.first->cast<CNodePtr>();
        auto value_ptr = GetValueNode(cnode->input(kInputNodeOutputIndexInTupleGetItem));
        MS_EXCEPTION_IF_NULL(value_ptr);
        return GetValue<int64_t>(value_ptr) == 1;
      }
      return false;
    });
  }
  return res ? decorated_->Run(node) : nullptr;
}

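// If the node's single input is a MakeTuple (e.g. AddN), clone the cnode and flatten the
// tuple elements into direct inputs before expanding.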
AnfNodePtr UnfoldMakeTupleDeco::Run(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() == kIndex2 && IsPrimitiveCNode(cnode->input(1), prim::kPrimMakeTuple)) {
    auto make_tuple_cnode = cnode->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple_cnode);
    std::vector<AnfNodePtr> new_inputs;
    new_inputs.push_back(cnode->input(0));
    for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
      new_inputs.push_back(make_tuple_cnode->input(i));
    }
    cnode = QuickCloneCNode(cnode);
    cnode->set_inputs(new_inputs);
  }
  return decorated_->Run(cnode);
}

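// Inline the expanded subgraph into the main graph and replace the expanding node with the
// inlined output.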
void InlineExpandFuncGraph(const AnfNodePtr &expanding_node, const FuncGraphPtr &expanded_graph) {
  auto main_graph = expanding_node->func_graph();
  auto mng = main_graph->manager();
  if (mng == nullptr) {
    mng = Manage(main_graph, true);
    main_graph->set_manager(mng);
  }
  auto cnode = expanding_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  AnfNodePtrList inp(cnode->inputs().begin() + 1, cnode->inputs().end());
  auto out = InlineClone(expanded_graph, main_graph, inp, cnode);
  (void)mng->Replace(expanding_node, out);
}

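// Returns true once an input with a Complex64/Complex128 tensor element type is found;
// returns false if a non-tensor input is encountered first.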
bool IsComplexOp(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input = cnode->input(i);
    TypePtr input_type = input->Type();
    if (input_type == nullptr || !input_type->isa<TensorType>()) {
      return false;
    }
    input_type = input_type->cast<TensorTypePtr>()->element();
    if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) {
      return true;
    }
  }
  return false;
}
}  // namespace mindspore::graphkernel