/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/graph_kernel/adapter/graph_kernel_expander_cloud.h"
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mindspore/core/ops/random_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "include/common/utils/anfalgo.h"
#include "utils/ms_context.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "backend/common/graph_kernel/graph_kernel_helper.h"

namespace mindspore::graphkernel {
namespace {
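// Checks whether a node can be handled by the DVM kernel generator: dynamic shape nodes must use the default
// format, AddN is limited to a bounded number of inputs, and the output data type must be in the supported set.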
bool DvmSupported(const AnfNodePtr &node) {
  // check format
  if (common::AnfAlgo::IsDynamicShape(node) && !CheckDefaultFormat(node)) {
    // dvm kernel shape inference uses the inputs' device shapes, but the output abstract shape inferred from the
    // device shape is not unique if some shape values are not a multiple of 16
    MS_LOG(DEBUG) << "skip node: " << node->fullname_with_scope()
                  << " because only default format is supported in dynamic shape";
    return false;
  }
  // check data type
  static std::set<TypeId> supported_types{kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeBool, kNumberTypeInt32,
                                          kNumberTypeBFloat16};
  if (IsPrimitiveCNode(node, prim::kPrimAddN)) {
    constexpr auto max_input_num = 10;
    auto input_num = common::AnfAlgo::GetInputTensorNum(node);
    if (input_num > max_input_num) {
      return false;
    }
  }
  auto cb = Callback::Instance();
  MS_EXCEPTION_IF_NULL(cb);
  auto node_output_type = cb->GetOutputType(node, 0);
  return supported_types.find(node_output_type) != supported_types.end();
}

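// Ops that can be expanded into basic ops, grouped by target device and fusion level (default kernel generator).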
const std::vector<OpWithLevel> expand_ops_with_level = {
  {kAllTarget, OpLevel_0, prim::kPrimAddN},
  {kAllTarget, OpLevel_0, prim::kPrimAssignAdd},
  {kAllTarget, OpLevel_1, prim::kPrimExpandDims},
  {kAllTarget, OpLevel_0, prim::kPrimGeLU},
  {kAllTarget, OpLevel_0, prim::kPrimGelu},
  {kAllTarget, OpLevel_0, prim::kPrimGeLUGrad},
  {kAllTarget, OpLevel_0, prim::kPrimSqrtGrad},
  {kAllTarget, OpLevel_0, prim::kPrimSquare},
  {kAllTarget, OpLevel_0, prim::kPrimTile},
  {kAscendDevice, OpLevel_0, prim::kLambApplyOptimizerAssign},
  {kAscendDevice, OpLevel_0, prim::kLambApplyWeightAssign},
  {kAscendDevice, OpLevel_0, prim::kPrimClipByNormNoDivSum},
  {kAscendDevice, OpLevel_1, prim::kSoftmaxGradExt},
  {kAscendDevice, OpLevel_0, prim::kFusedMulAdd},
  {kGPUDevice, OpLevel_0, prim::kPrimErfc},
  {kGPUDevice, OpLevel_1, prim::kPrimAdamWeightDecay},
  {kGPUDevice, OpLevel_1, prim::kPrimBatchMatMul},
  {kGPUDevice, OpLevel_0, prim::kPrimBiasAdd},
  {kGPUDevice, OpLevel_1, prim::kPrimBiasAddGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimDropout},
  {kGPUDevice, OpLevel_0, prim::kPrimDropoutGrad},
  {kGPUDevice, OpLevel_1, prim::kPrimMaximumGrad},
  {kGPUDevice, OpLevel_1, prim::kPrimMinimumGrad},
  {kGPUDevice, OpLevel_1, prim::kPrimLayerNorm},
  {kGPUDevice, OpLevel_1, prim::kPrimLayerNormGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimLogSoftmax},
  {kGPUDevice, OpLevel_0, prim::kPrimLogSoftmaxGrad},
  {kGPUDevice, OpLevel_1, prim::kPrimMatMul},
  {kGPUDevice, OpLevel_1, prim::kPrimReduceMean},
  {kGPUDevice, OpLevel_1, prim::kPrimArgMaxWithValue},
  {kGPUDevice, OpLevel_1, prim::kPrimArgMinWithValue},
  {kGPUDevice, OpLevel_0, prim::kPrimReLU},
  {kGPUDevice, OpLevel_0, prim::kPrimReluGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimSigmoid},
  {kGPUDevice, OpLevel_0, prim::kPrimSigmoidGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogits},
  {kGPUDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogitsGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimSlice},
  {kGPUDevice, OpLevel_1, prim::kPrimSoftmax},
  {kGPUDevice, OpLevel_1, prim::kPrimSoftmaxCrossEntropyWithLogits},
  {kGPUDevice, OpLevel_0, prim::kPrimSquaredDifference},
  {kGPUDevice, OpLevel_0, prim::kPrimSqueeze},
  {kGPUDevice, OpLevel_0, prim::kPrimEqualCount},
  {kGPUDevice, OpLevel_0, prim::kPrimSquareSumAll},
  {kGPUDevice, OpLevel_0, prim::kPrimIdentityMath},
  {kGPUDevice, OpLevel_0, prim::kPrimOnesLike},
  {kGPUDevice, OpLevel_0, prim::kPrimStandardNormal},
  {kCPUDevice, OpLevel_0, prim::kPrimOnesLike},
  {kCPUDevice, OpLevel_0, prim::kPrimBiasAdd},
  {kCPUDevice, OpLevel_1, prim::kPrimBiasAddGrad},
  {kCPUDevice, OpLevel_0, prim::kPrimReLU},
  {kCPUDevice, OpLevel_1, prim::kPrimMaximumGrad},
  {kCPUDevice, OpLevel_1, prim::kPrimMinimumGrad},
  {kCPUDevice, OpLevel_1, prim::kPrimAdam},
  {kCPUDevice, OpLevel_1, prim::kPrimTanhGrad},
  {kCPUDevice, OpLevel_1, prim::kPrimSoftplus},
  {kCPUDevice, OpLevel_1, prim::kPrimSoftplusGrad},
};

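// Additional ops that are expanded when the AKG_V2 kernel generator is selected.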
const std::vector<OpWithLevel> expand_ops_with_level_v2 = {
  // CPU
  {kCPUDevice, OpLevel_0, prim::kPrimIdentityMath},
  {kCPUDevice, OpLevel_0, prim::kPrimSqueeze},
  {kCPUDevice, OpLevel_0, prim::kPrimSlice},

  // GPU
  {kGPUDevice, OpLevel_0, prim::kPrimBiasAdd},
  {kGPUDevice, OpLevel_0, prim::kPrimDropout},
  {kGPUDevice, OpLevel_0, prim::kPrimDropoutGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimLayerNorm},
  {kGPUDevice, OpLevel_0, prim::kPrimLayerNormGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimRelu},
  {kGPUDevice, OpLevel_0, prim::kPrimReluGrad},
  {kGPUDevice, OpLevel_0, prim::kPrimClipByNorm},
};

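// Ops that are expanded when the DVM kernel generator is selected (Ascend only).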
const std::vector<OpWithLevel> expand_ops_with_level_dvm = {
  {kAscendDevice, OpLevel_0, prim::kPrimAdam},
  {kAscendDevice, OpLevel_0, prim::kPrimAddN},
  {kAscendDevice, OpLevel_0, prim::kPrimBiasAdd},
  {kAscendDevice, OpLevel_0, prim::kPrimBiasAddGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimFillV2},
  {kAscendDevice, OpLevel_0, prim::kPrimGeLU},
  {kAscendDevice, OpLevel_0, prim::kPrimGelu},
  {kAscendDevice, OpLevel_0, prim::kPrimFastGelu},
  {kAscendDevice, OpLevel_0, prim::kPrimFastGeluGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimFastGeLU},
  {kAscendDevice, OpLevel_0, prim::kPrimFastGeLUGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSiLU},
  {kAscendDevice, OpLevel_0, prim::kPrimSiLUGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimGeLUGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimRsqrtGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSqrtGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSquare},
  {kAscendDevice, OpLevel_0, prim::kPrimTile},
  {kAscendDevice, OpLevel_0, prim::kPrimClipByNormNoDivSum},
  {kAscendDevice, OpLevel_0, prim::kFusedMulAdd},
  {kAscendDevice, OpLevel_0, prim::kPrimSigmoid},
  {kAscendDevice, OpLevel_0, prim::kPrimSigmoidGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogits},
  {kAscendDevice, OpLevel_0, prim::kPrimSigmoidCrossEntropyWithLogitsGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimSquaredDifference},
  {kAscendDevice, OpLevel_0, prim::kPrimTanhGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimOnesLike},
  {kAscendDevice, OpLevel_0, prim::kPrimZerosLike},
  {kAscendDevice, OpLevel_0, prim::kPrimReduceMean},
  {kAscendDevice, OpLevel_1, prim::kPrimLogSoftmaxGrad},  // will be split into multiple subgraphs due to ReduceSum
  {kAscendDevice, OpLevel_0, prim::kPrimReLU},
  {kAscendDevice, OpLevel_0, prim::kPrimReluGrad},
  {kAscendDevice, OpLevel_0, prim::kPrimAssignAdd},
  {kAscendDevice, OpLevel_0, prim::kLambApplyOptimizerAssign},
  {kAscendDevice, OpLevel_0, prim::kLambApplyWeightAssign},
  {kAscendDevice, OpLevel_0, prim::kPrimAdamApplyOneWithDecay},
  {kAscendDevice, OpLevel_1, prim::kPrimExpandDims},
  {kAscendDevice, OpLevel_1, prim::kPrimSqueeze},
  {kAscendDevice, OpLevel_1, prim::kSoftmaxGradExt},
  {kAscendDevice, OpLevel_1, prim::kPrimApplyMomentum},
};
}  // namespace

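// Builds the op list for the current kernel generator: selects the op table for AKG, AKG_V2 or DVM, applies the
// enable/disable expand-op flags, and filters out excluded ops.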
std::vector<PrimitivePtr> GraphKernelExpanderCloud::GetExpanderOps() {
  const auto &flags = GraphKernelFlags::GetInstance();
  std::vector<std::string> disable_expand_ops = flags.disable_expand_ops;
  auto cb = Callback::Instance();

  std::vector<OpWithLevel> expand_ops;
  std::vector<std::string> disable_expand_op_list_v2 = {
    "OnesLike", "OneHot", "StridedSlice", "CumSum", "Transpose", "BatchMatMul", "MatMul", "ExpandDims", "BroadcastTo"};
  if (flags.kernel_generator == "AKG_V2") {
    expand_ops = expand_ops_with_level;
    expand_ops.insert(expand_ops.end(), expand_ops_with_level_v2.begin(), expand_ops_with_level_v2.end());
    if (cb->GetTargetFromContext() == kGPUDevice) {
      for (const std::string &item : disable_expand_op_list_v2) {
        if (std::find(flags.enable_expand_ops.begin(), flags.enable_expand_ops.end(), item) ==
            flags.enable_expand_ops.end()) {
          disable_expand_ops.push_back(item);
        }
      }
    }
  } else if (flags.kernel_generator == "DVM") {
    expand_ops = expand_ops_with_level_dvm;
  } else {
    expand_ops = expand_ops_with_level;
  }
  auto ops = GkUtils::GetValidOps(expand_ops, flags.fusion_ops_level, flags.enable_expand_ops_only,
                                  flags.enable_expand_ops, disable_expand_ops);
  return GkUtils::FilterExcludedOps(ops);
}

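// The op list seen by the base expander pass comes directly from GetExpanderOps().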
std::vector<PrimitivePtr> GraphKernelExpanderCloud::InitOpList() { return GraphKernelExpanderCloud::GetExpanderOps(); }

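// Decides whether a node can be expanded: complex ops are always expanded (except with DVM), DVM nodes must pass
// DvmSupported, and dynamic shape nodes additionally require enable_dynamic_shape_fusion (plus, for non-DVM
// generators, membership in a small dynamic-shape op list); dynamic rank nodes are never expanded.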
bool GraphKernelExpanderCloud::CanExpand(const CNodePtr &node) const {
  bool is_dvm = (GraphKernelFlags::GetInstance().kernel_generator == "DVM");
  if (IsComplexOp(node) && !is_dvm) {
    return true;
  }
  if (!GraphKernelExpander::CanExpand(node)) {
    return false;
  }
  if (is_dvm && !DvmSupported(node)) {
    return false;
  }
  if (!common::AnfAlgo::IsDynamicShape(node)) {
    // for static shape cases, the node can be expanded if it is a complex op or in the op list
    return true;
  }

  // deal with dynamic shape cases
  // a node with dynamic rank will not be expanded
  if (common::AnfAlgo::IsDynamicRankNode(node)) {
    return false;
  }

  auto enable_dynamic_shape = GraphKernelFlags::GetInstance().enable_dynamic_shape_fusion;
  if (is_dvm) {
    return enable_dynamic_shape;
  }

  std::vector<PrimitivePtr> expand_ops_dyn = {prim::kPrimReLU, prim::kPrimReluGrad, prim::kPrimBiasAdd,
                                              prim::kPrimBiasAddGrad, prim::kPrimDropout};

  bool dyn_can_expand_op = std::any_of(expand_ops_dyn.begin(), expand_ops_dyn.end(),
                                       [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
  // the dynamic shape node can be expanded
  return (enable_dynamic_shape && dyn_can_expand_op);
}

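// The expander for each node is built from a LitegraphExpander bound to the current backend callback.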
ExpanderPtr GraphKernelExpanderCloud::InitExpander(const AnfNodePtr &node) {
  auto e = GetExpander(node, std::make_shared<LitegraphExpander>(Callback::Instance()));
  return e;
}
}  // namespace mindspore::graphkernel