/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/graph_kernel/adapter/graph_kernel_expander_cloud.h"
#include <set>
#include "mindspore/core/ops/random_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "include/common/utils/anfalgo.h"
#include "utils/ms_context.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"
namespace mindspore::graphkernel {
namespace {
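// Check whether the DVM kernel generator can handle this node: the format under
// dynamic shape, the input count (for AddN), and the output data type.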
bool DvmSupported(const AnfNodePtr &node) {
  // check format
  if (common::AnfAlgo::IsDynamicShape(node) && !CheckDefaultFormat(node)) {
    // DVM kernel shape inference uses the inputs' device shapes, but the output abstract shape
    // inferred from a device shape is not unique if some shape values are not a multiple of 16
    MS_LOG(DEBUG) << "skip node: " << node->fullname_with_scope()
                  << " because only default format is supported in dynamic shape";
    return false;
  }
  // check data type
  static std::set<TypeId> supported_types{kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeBool, kNumberTypeInt32,
                                          kNumberTypeBFloat16};
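  // check input count: AddN with more than 10 inputs is not expanded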
  if (IsPrimitiveCNode(node, prim::kPrimAddN)) {
    constexpr auto max_input_num = 10;
    auto input_num = common::AnfAlgo::GetInputTensorNum(node);
    if (input_num > max_input_num) {
      return false;
    }
  }
  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);
  auto node_output_type = cb->GetOutputType(node, 0);
  return supported_types.find(node_output_type) != supported_types.end();
}

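// Candidate ops for expansion: {target device, op level, primitive}. Entries are
// later filtered against flags.fusion_ops_level in GetExpanderOps.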
const std::vector<OpWithLevel> expand_ops_with_level = {
  {kAllTarget, OpLevel_0, prim::kPrimAddN},
  {kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
  {kAllTarget, OpLevel_1, prim::kPrimExpandDims},
  {kAllTarget, OpLevel_0, prim::kPrimGeLU},
  {kAllTarget, OpLevel_0, prim::kPrimGelu},
  {kAllTarget, OpLevel_0, prim::kPrimGeLUGrad},
  {kAllTarget, OpLevel_0, prim::kPrimSqrtGrad},
  {kAllTarget, OpLevel_0, prim::kPrimSquare},
  {kAllTarget, OpLevel_0, prim::kPrimTile},
  {kAscendDevice, OpLevel_0, prim::kLambApplyOptimizerAssign},
  {kAscendDevice, OpLevel_0, prim::kLambApplyWeightAssign},
  {kAscendDevice, OpLevel_0, prim::kPrimClipByNormNoDivSum},
  {kAscendDevice, OpLevel_1, prim::kSoftmaxGradExt},
  {kAscendDevice, OpLevel_0, prim::kFusedMulAdd},
  {kGPUDevice, OpLevel_0, prim::kPrimErfc},
  {kGPUDevice, OpLevel_1, prim::kPrimAdamWeightDecay},
  {kGPUDevice, OpLevel_1, prim::kPrimBatchMatMul},
  {kGPUDevice, OpLevel_0, prim::kPrimBiasAdd},
  {kGPUDevice, OpLevel_1, prim::kPrimBiasAddGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimDropout},
  {kGPUDevice, OpLevel_0, prim::kPrimDropoutGrad},
  {kGPUDevice, OpLevel_1, prim::kPrimMaximumGrad},
  {kGPUDevice, OpLevel_1, prim::kPrimMinimumGrad},
  {kGPUDevice, OpLevel_1, prim::kPrimLayerNorm},
  {kGPUDevice, OpLevel_1, prim::kPrimLayerNormGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimLogSoftmax},
  {kGPUDevice, OpLevel_0, prim::kPrimLogSoftmaxGrad},
  {kGPUDevice, OpLevel_1, prim::kPrimMatMul},
  {kGPUDevice, OpLevel_1, prim::kPrimReduceMean},
  {kGPUDevice, OpLevel_1, prim::kPrimArgMaxWithValue},
  {kGPUDevice, OpLevel_1, prim::kPrimArgMinWithValue},
  {kGPUDevice, OpLevel_0, prim::kPrimReLU},
  {kGPUDevice, OpLevel_0, prim::kPrimReluGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimSigmoid},
  {kGPUDevice, OpLevel_0, prim::kPrimSigmoidGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogits},
  {kGPUDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogitsGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimSlice},
  {kGPUDevice, OpLevel_1, prim::kPrimSoftmax},
  {kGPUDevice, OpLevel_1, prim::kPrimSoftmaxCrossEntropyWithLogits},
  {kGPUDevice, OpLevel_0, prim::kPrimSquaredDifference},
  {kGPUDevice, OpLevel_0, prim::kPrimSqueeze},
  {kGPUDevice, OpLevel_0, prim::kPrimEqualCount},
  {kGPUDevice, OpLevel_0, prim::kPrimSquareSumAll},
  {kGPUDevice, OpLevel_0, prim::kPrimIdentityMath},
  {kGPUDevice, OpLevel_0, prim::kPrimOnesLike},
  {kGPUDevice, OpLevel_0, prim::kPrimStandardNormal},
  {kCPUDevice, OpLevel_0, prim::kPrimOnesLike},
  {kCPUDevice, OpLevel_0, prim::kPrimBiasAdd},
  {kCPUDevice, OpLevel_1, prim::kPrimBiasAddGrad},
  {kCPUDevice, OpLevel_0, prim::kPrimReLU},
  {kCPUDevice, OpLevel_1, prim::kPrimMaximumGrad},
  {kCPUDevice, OpLevel_1, prim::kPrimMinimumGrad},
  {kCPUDevice, OpLevel_1, prim::kPrimAdam},
  {kCPUDevice, OpLevel_1, prim::kPrimTanhGrad},
  {kCPUDevice, OpLevel_1, prim::kPrimSoftplus},
  {kCPUDevice, OpLevel_1, prim::kPrimSoftplusGrad},
};

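// Additional candidates appended to the base list when kernel_generator is "AKG_V2".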
const std::vector<OpWithLevel> expand_ops_with_level_v2 = {
  // CPU
  {kCPUDevice, OpLevel_0, prim::kPrimIdentityMath},
  {kCPUDevice, OpLevel_0, prim::kPrimSqueeze},
  {kCPUDevice, OpLevel_0, prim::kPrimSlice},

  // GPU
  {kGPUDevice, OpLevel_0, prim::kPrimBiasAdd},
  {kGPUDevice, OpLevel_0, prim::kPrimDropout},
  {kGPUDevice, OpLevel_0, prim::kPrimDropoutGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimLayerNorm},
  {kGPUDevice, OpLevel_0, prim::kPrimLayerNormGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimRelu},
  {kGPUDevice, OpLevel_0, prim::kPrimReluGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimClipByNorm},
};

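// Candidates used when kernel_generator is "DVM" (all entries target Ascend).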
const std::vector<OpWithLevel> expand_ops_with_level_dvm = {
  {kAscendDevice, OpLevel_0, prim::kPrimAdam},
  {kAscendDevice, OpLevel_0, prim::kPrimAddN},
  {kAscendDevice, OpLevel_0, prim::kPrimBiasAdd},
  {kAscendDevice, OpLevel_0, prim::kPrimBiasAddGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimFillV2},
  {kAscendDevice, OpLevel_0, prim::kPrimGeLU},
  {kAscendDevice, OpLevel_0, prim::kPrimGelu},
  {kAscendDevice, OpLevel_0, prim::kPrimFastGelu},
  {kAscendDevice, OpLevel_0, prim::kPrimFastGeluGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimFastGeLU},
  {kAscendDevice, OpLevel_0, prim::kPrimFastGeLUGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSiLU},
  {kAscendDevice, OpLevel_0, prim::kPrimSiLUGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimGeLUGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimRsqrtGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSqrtGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSquare},
  {kAscendDevice, OpLevel_0, prim::kPrimTile},
  {kAscendDevice, OpLevel_0, prim::kPrimClipByNormNoDivSum},
  {kAscendDevice, OpLevel_0, prim::kFusedMulAdd},
  {kAscendDevice, OpLevel_0, prim::kPrimSigmoid},
  {kAscendDevice, OpLevel_0, prim::kPrimSigmoidGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogits},
  {kAscendDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogitsGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSquaredDifference},
  {kAscendDevice, OpLevel_0, prim::kPrimTanhGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimOnesLike},
  {kAscendDevice, OpLevel_0, prim::kPrimZerosLike},
  {kAscendDevice, OpLevel_0, prim::kPrimReduceMean},
  {kAscendDevice, OpLevel_1, prim::kPrimLogSoftmaxGrad},  // will be split into multiple subgraphs because of ReduceSum
  {kAscendDevice, OpLevel_0, prim::kPrimReLU},
  {kAscendDevice, OpLevel_0, prim::kPrimReluGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimAssignAdd},
  {kAscendDevice, OpLevel_0, prim::kLambApplyOptimizerAssign},
  {kAscendDevice, OpLevel_0, prim::kLambApplyWeightAssign},
  {kAscendDevice, OpLevel_0, prim::kPrimAdamApplyOneWithDecay},
  {kAscendDevice, OpLevel_1, prim::kPrimExpandDims},
  {kAscendDevice, OpLevel_1, prim::kPrimSqueeze},
  {kAscendDevice, OpLevel_1, prim::kSoftmaxGradExt},
  {kAscendDevice, OpLevel_1, prim::kPrimApplyMomentum},
};
}  // namespace

std::vector<PrimitivePtr> GraphKernelExpanderCloud::GetExpanderOps() {
  const auto &flags = GraphKernelFlags::GetInstance();
  std::vector<std::string> disable_expand_ops = flags.disable_expand_ops;
  auto cb = Callback::Instance();

  std::vector<OpWithLevel> expand_ops;
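  // Ops disabled by default for AKG_V2 on GPU unless explicitly listed in enable_expand_ops.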
  std::vector<std::string> disable_expand_op_list_v2 = {
    "OnesLike", "OneHot", "StridedSlice", "CumSum", "Transpose", "BatchMatMul", "MatMul", "ExpandDims", "BroadcastTo"};
  if (flags.kernel_generator == "AKG_V2") {
    expand_ops = expand_ops_with_level;
    expand_ops.insert(expand_ops.end(), expand_ops_with_level_v2.begin(), expand_ops_with_level_v2.end());
    if (cb->GetTargetFromContext() == kGPUDevice) {
      for (const std::string &item : disable_expand_op_list_v2) {
        if (std::find(flags.enable_expand_ops.begin(), flags.enable_expand_ops.end(), item) ==
            flags.enable_expand_ops.end()) {
          disable_expand_ops.push_back(item);
        }
      }
    }
  } else if (flags.kernel_generator == "DVM") {
    expand_ops = expand_ops_with_level_dvm;
  } else {
    expand_ops = expand_ops_with_level;
  }
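  // Keep only candidates allowed by fusion_ops_level and the enable/disable flags,
  // then drop any ops that are excluded.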
  auto ops = GkUtils::GetValidOps(expand_ops, flags.fusion_ops_level, flags.enable_expand_ops_only,
                                  flags.enable_expand_ops, disable_expand_ops);
  return GkUtils::FilterExcludedOps(ops);
}

std::vector<PrimitivePtr> GraphKernelExpanderCloud::InitOpList() { return GraphKernelExpanderCloud::GetExpanderOps(); }

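// Decide whether a node should be expanded, taking the kernel generator and
// dynamic shape support into account.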
bool GraphKernelExpanderCloud::CanExpand(const CNodePtr &node) const {
  bool is_dvm = (GraphKernelFlags::GetInstance().kernel_generator == "DVM");
  if (IsComplexOp(node) && !is_dvm) {
    return true;
  }
  if (!GraphKernelExpander::CanExpand(node)) {
    return false;
  }
  if (is_dvm && !DvmSupported(node)) {
    return false;
  }
  if (!common::AnfAlgo::IsDynamicShape(node)) {
    // for static shape cases, the node can be expanded if it is a complex op
    // or is in the op list
    return true;
  }

  // deal with dynamic shape cases:
  // a node with dynamic rank will not be expanded
  if (common::AnfAlgo::IsDynamicRankNode(node)) {
    return false;
  }

  auto enable_dynamic_shape = GraphKernelFlags::GetInstance().enable_dynamic_shape_fusion;
  if (is_dvm) {
    return enable_dynamic_shape;
  }

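  // Non-DVM path: under dynamic shape, only the following ops may be expanded.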
  std::vector<PrimitivePtr> expand_ops_dyn = {prim::kPrimReLU, prim::kPrimReluGrad, prim::kPrimBiasAdd,
                                              prim::kPrimBiasAddGrad, prim::kPrimDropout};

  bool dyn_can_expand_op = std::any_of(expand_ops_dyn.begin(), expand_ops_dyn.end(),
                                       [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  // the dynamic shape node can be expanded only if dynamic shape fusion is enabled and the op is in the list above
  return (enable_dynamic_shape && dyn_can_expand_op);
}

ExpanderPtr GraphKernelExpanderCloud::InitExpander(const AnfNodePtr &node) {
  auto e = GetExpander(node, std::make_shared<LitegraphExpander>(Callback::Instance()));
  return e;
}
}  // namespace mindspore::graphkernel