1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/groupnorm_fusion.h"
19 #include <algorithm>
20 #include <vector>
21 #include <memory>
22 #include "mindspore/core/ops/math_ops.h"
23 #include "mindspore/core/ops/lite_ops.h"
24 #include "mindspore/core/ops/array_ops.h"
25 #include "ops/fusion/groupnorm_fusion.h"
26 #include "include/common/utils/utils.h"
27 #include "tools/optimizer/common/gllo_utils.h"
28 #include "securec/include/securec.h"
29 #include "nnacl/op_base.h"
30 #include "src/common/ops/ops_utils.h"
31 #include "ops/op_utils.h"
32
33 namespace mindspore {
34 namespace opt {
35 namespace {
GetAxis(const BaseRef & n,std::vector<int> * axes)36 STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
37 MS_ASSERT(axes != nullptr);
38 if (utils::isa<ParameterPtr>(n)) {
39 auto axes_param = utils::cast<ParameterPtr>(n);
40 if (!axes_param->has_default() || axes_param->default_param() == nullptr) {
41 return lite::RET_NOT_SUPPORT;
42 }
43 auto axes_value = axes_param->default_param()->cast<tensor::TensorPtr>();
44 if (axes_value == nullptr) {
45 return lite::RET_ERROR;
46 }
47 if (axes_value->data_type() != kNumberTypeInt && axes_value->data_type() != kNumberTypeInt32) {
48 MS_LOG(ERROR) << "reduce's axes should be integer, now is " << axes_value->data_type();
49 return lite::RET_ERROR;
50 }
51 if (axes_value->data_c() == nullptr) {
52 return lite::RET_ERROR;
53 }
54 if (axes_value->shape().size() > 1) {
55 return lite::RET_ERROR;
56 }
57 axes->resize(1);
58 if (!axes_value->shape().empty()) {
59 MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
60 axes->resize(static_cast<size_t>(axes_value->shape()[0]));
61 }
62 if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) == EOK) {
63 return lite::RET_OK;
64 }
65 }
66 if (utils::isa<ValueNodePtr>(n)) {
67 auto axes_value_node = utils::cast<ValueNodePtr>(n);
68 *axes = CastToInt(axes_value_node->value());
69 return lite::RET_OK;
70 }
71 return lite::RET_ERROR;
72 }
73
IsReduceSumNode(const EquivPtr & equiv,const VarPtr & input_prim,const VarPtr & input_axes,std::vector<int> * axes)74 bool IsReduceSumNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes,
75 std::vector<int> *axes) {
76 MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
77 auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
78 MS_ASSERT(reduce_value != nullptr);
79 auto mean2_primitive = ops::GetOperator<ops::ReduceFusion>(reduce_value);
80 MS_CHECK_TRUE_RET(mean2_primitive != nullptr, false);
81 auto mean2_primitive_c = mean2_primitive->GetPrim();
82 if (mean2_primitive_c->GetAttr(ops::kMode) == nullptr || mean2_primitive->get_mode() != mindspore::Reduce_Sum) {
83 return false;
84 }
85 if (GetAxis((*equiv)[input_axes], axes) != lite::RET_OK) {
86 return false;
87 }
88 return true;
89 }
90
IsReduceMeanNode(const EquivPtr & equiv,const VarPtr & input_prim,const VarPtr & input_axes,std::vector<int> * axes)91 bool IsReduceMeanNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes,
92 std::vector<int> *axes) {
93 MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
94 auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
95 MS_ASSERT(reduce_value != nullptr);
96 auto mean2_primitive = ops::GetOperator<ops::ReduceFusion>(reduce_value);
97 MS_CHECK_TRUE_RET(mean2_primitive != nullptr, false);
98 auto mean2_primitive_c = mean2_primitive->GetPrim();
99 if (mean2_primitive_c->GetAttr(ops::kMode) == nullptr || mean2_primitive->get_mode() != mindspore::Reduce_Mean) {
100 return false;
101 }
102 if (GetAxis((*equiv)[input_axes], axes) != lite::RET_OK) {
103 return false;
104 }
105 return true;
106 }
107 } // namespace
108
Init() const109 bool GroupNormFusion::Init() const {
110 input_ = std::make_shared<Var>();
111 MS_CHECK_TRUE_RET(input_ != nullptr, false);
112 mean1_ = std::make_shared<Var>();
113 MS_CHECK_TRUE_RET(mean1_ != nullptr, false);
114 mean1_axis_ = std::make_shared<Var>();
115 MS_CHECK_TRUE_RET(mean1_axis_ != nullptr, false);
116 sum1_ = std::make_shared<Var>();
117 MS_CHECK_TRUE_RET(sum1_ != nullptr, false);
118 sum1_axis_ = std::make_shared<Var>();
119 MS_CHECK_TRUE_RET(sum1_axis_ != nullptr, false);
120 reshape1_axis_ = std::make_shared<Var>();
121 MS_CHECK_TRUE_RET(reshape1_axis_ != nullptr, false);
122 reshape2_axis_ = std::make_shared<Var>();
123 MS_CHECK_TRUE_RET(reshape2_axis_ != nullptr, false);
124 gamma_ = std::make_shared<Var>();
125 MS_CHECK_TRUE_RET(gamma_ != nullptr, false);
126 beta_ = std::make_shared<Var>();
127 MS_CHECK_TRUE_RET(beta_ != nullptr, false);
128 epsilon_ = std::make_shared<Var>();
129 MS_CHECK_TRUE_RET(epsilon_ != nullptr, false);
130 real_div_divider_ = std::make_shared<Var>();
131 MS_CHECK_TRUE_RET(real_div_divider_ != nullptr, false);
132
133 return true;
134 }
135
CheckPattern(const FuncGraphPtr & func_graph,const EquivPtr & equiv,int * num_groups,float * epsilon,bool * affine) const136 bool GroupNormFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *num_groups,
137 float *epsilon, bool *affine) const {
138 MS_ASSERT(equiv != nullptr);
139 MS_ASSERT(epsilon != nullptr);
140 MS_ASSERT(num_groups != nullptr);
141 MS_ASSERT(epsilon != nullptr);
142 MS_ASSERT(affine != nullptr);
143
144 // beta
145 auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
146 MS_ASSERT(beta_node != nullptr);
147 if (!beta_node->isa<Parameter>()) {
148 return false;
149 }
150 auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
151 MS_CHECK_TRUE_RET(beta_param != nullptr, false);
152 auto beta_tensor = beta_param->cast<tensor::TensorPtr>();
153 MS_CHECK_TRUE_RET(beta_tensor != nullptr, false);
154 std::vector<int> beta_shape;
155 (void)std::transform(beta_tensor->shape().begin(), beta_tensor->shape().end(), std::back_inserter(beta_shape),
156 [](int64_t val) { return static_cast<int>(val); });
157 // gamma
158 auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
159 MS_ASSERT(gamma_node != nullptr);
160 if (!gamma_node->isa<Parameter>()) {
161 return false;
162 }
163 auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
164 MS_CHECK_TRUE_RET(gamma_param != nullptr, false);
165 auto gamma_tensor = gamma_param->cast<tensor::TensorPtr>();
166 MS_CHECK_TRUE_RET(gamma_tensor != nullptr, false);
167 std::vector<int> gamma_shape;
168 (void)std::transform(gamma_tensor->shape().begin(), gamma_tensor->shape().end(), std::back_inserter(gamma_shape),
169 [](int64_t val) { return static_cast<int>(val); });
170 // epsilon
171 auto epsilon_node = utils::cast<AnfNodePtr>((*equiv)[epsilon_]);
172 MS_ASSERT(epsilon_node != nullptr);
173 if (!epsilon_node->isa<ValueNode>()) {
174 return false;
175 }
176 auto epsilon_value_node = epsilon_node->cast<ValueNodePtr>();
177 MS_CHECK_TRUE_RET(epsilon_value_node != nullptr, false);
178 auto epsilon_value = epsilon_value_node->value();
179 MS_CHECK_TRUE_RET(epsilon_value != nullptr, false);
180 if (!epsilon_value->isa<tensor::Tensor>()) {
181 std::cout << "CheckPattern:epsilon_value_node not tensor" << std::endl;
182 return false;
183 }
184 auto epsilon_tensor = epsilon_value->cast<tensor::TensorPtr>();
185 MS_CHECK_TRUE_RET(epsilon_tensor != nullptr, false);
186 TypeId tensor_type = epsilon_tensor->Dtype()->type_id();
187 if (!(tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
188 std::cout << "CheckPattern:epsilon_value_node not float" << std::endl;
189
190 return false;
191 }
192 auto epsilon_shape = epsilon_tensor->shape();
193 // sum1
194 std::vector<int> sum1_axes;
195 if (!IsReduceSumNode(equiv, sum1_, sum1_axis_, &sum1_axes)) {
196 return false;
197 }
198 // mean1
199 std::vector<int> mean1_axes;
200 if (!IsReduceMeanNode(equiv, mean1_, mean1_axis_, &mean1_axes)) {
201 return false;
202 }
203 auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
204 MS_ASSERT(input_node != nullptr);
205 if (!utils::isa<CNodePtr>(input_node)) {
206 return false;
207 }
208 if (mean1_axes != sum1_axes) {
209 return false;
210 }
211 if (gamma_shape != beta_shape) {
212 return false;
213 }
214 if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
215 MS_CHECK_TRUE_RET(epsilon_tensor->data_c() != nullptr, false);
216 auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->data_c());
217 *epsilon = epsilon_data[0];
218 } else {
219 return false;
220 }
221 std::vector<int> reshape1_axes;
222 if (GetAxis((*equiv)[reshape1_axis_], &reshape1_axes) != lite::RET_OK) {
223 return false;
224 }
225 if (reshape1_axes.size() != C3NUM) {
226 return false;
227 }
228 *num_groups = reshape1_axes.at(C1NUM);
229 *affine = true;
230 return true;
231 }
232
CreateGroupNormNode(const FuncGraphPtr & func_graph,const EquivPtr & equiv,int num_groups,float epsilon) const233 CNodePtr GroupNormFusion::CreateGroupNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int num_groups,
234 float epsilon) const {
235 MS_ASSERT(func_graph != nullptr);
236 MS_ASSERT(equiv != nullptr);
237 PrimitiveCPtr primitive_c = nullptr;
238
239 auto layer_norm_primitive = std::make_shared<ops::GroupNormFusion>();
240 MS_CHECK_TRUE_RET(layer_norm_primitive != nullptr, nullptr);
241 layer_norm_primitive->Init(num_groups, epsilon, true);
242 auto layer_norm_primitive_c = layer_norm_primitive->GetPrim();
243 MS_CHECK_TRUE_RET(layer_norm_primitive_c != nullptr, nullptr);
244 primitive_c = layer_norm_primitive_c;
245
246 auto value_node = NewValueNode(primitive_c);
247 MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
248 std::vector<AnfNodePtr> new_node_inputs = {value_node};
249 auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
250 MS_ASSERT(input_node != nullptr);
251 new_node_inputs.push_back(input_node);
252 auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
253 MS_ASSERT(gamma_node != nullptr);
254 new_node_inputs.push_back(gamma_node);
255 auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
256 MS_ASSERT(beta_node != nullptr);
257 new_node_inputs.push_back(beta_node);
258 auto new_node = func_graph->NewCNode(new_node_inputs);
259 return new_node;
260 }
261
// Pass callback invoked for a node that matched DefinePattern().
//
// Returns the replacement GroupNormFusion node, or nullptr when the match is
// rejected: null arguments, `node` not a CNode, node marked as a train op,
// CheckPattern() failure, node creation failure, or a matched node without an
// abstract. The new node receives a clone of the matched node's abstract and
// the name "group_norm_" + the matched node's full name.
const AnfNodePtr GroupNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &equiv) const {
  const bool args_valid = func_graph != nullptr && node != nullptr && equiv != nullptr;
  if (!args_valid) {
    MS_LOG(ERROR) << "input param is nullptr, do group norm fusion failed.";
    return nullptr;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto matched_root = node->cast<CNodePtr>();
  if (IsMarkedTrainOp(matched_root)) {
    return nullptr;
  }

  // Attributes filled in by CheckPattern() on success.
  float eps = 0.0f;
  int groups = 0;
  bool has_affine = true;
  if (!CheckPattern(func_graph, equiv, &groups, &eps, &has_affine)) {
    return nullptr;
  }

  auto fused_cnode = CreateGroupNormNode(func_graph, equiv, groups, eps);
  if (fused_cnode == nullptr) {
    MS_LOG(DEBUG) << "create norm cnode failed";
    return nullptr;
  }

  // The fused node stands in for the matched root, so it inherits that node's abstract.
  auto root_abstract = matched_root->abstract();
  MS_CHECK_TRUE_RET(root_abstract != nullptr, nullptr);
  fused_cnode->set_abstract(root_abstract->Clone());
  fused_cnode->set_fullname_with_scope("group_norm_" + matched_root->fullname_with_scope());
  MS_LOG(DEBUG) << "group_norm_ node:" << fused_cnode->fullname_with_scope() << " fusion success";
  return fused_cnode;
}
292
DefinePattern() const293 const BaseRef GroupNormFusion::DefinePattern() const {
294 if (!Init()) {
295 MS_LOG(ERROR) << "initial member failed.";
296 return {};
297 }
298
299 auto is_reshape1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
300 MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
301 VectorRef reshape_ref1 = VectorRef({is_reshape1, input_, reshape1_axis_});
302 VectorRef mean1_ref = VectorRef({mean1_, reshape_ref1, mean1_axis_});
303 auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
304 MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
305 VectorRef sub1_ref = VectorRef({is_sub1, reshape_ref1, mean1_ref});
306
307 auto is_sqare = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquare>);
308 MS_CHECK_TRUE_RET(is_sqare != nullptr, {});
309 VectorRef square_ref = VectorRef({is_sqare, sub1_ref});
310 VectorRef sum1_ref = VectorRef({sum1_, square_ref, sum1_axis_});
311 auto is_realdiv1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRealDiv>);
312 MS_CHECK_TRUE_RET(is_realdiv1 != nullptr, {});
313 VectorRef realdiv1_ref = VectorRef({is_realdiv1, sum1_ref, real_div_divider_});
314 auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
315 MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
316 VectorRef add1_ref = VectorRef({is_add1, realdiv1_ref, epsilon_});
317 auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
318 MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
319 VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
320 auto is_realdiv2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRealDiv>);
321 MS_CHECK_TRUE_RET(is_realdiv2 != nullptr, {});
322 VectorRef realdiv2_ref = VectorRef({is_realdiv2, sub1_ref, sqrt_ref});
323
324 auto is_reshape2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
325 MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
326 VectorRef reshape_ref2 = VectorRef({is_reshape2, realdiv2_ref, reshape2_axis_});
327 auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
328 MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
329 VectorRef mul1_ref = VectorRef({is_mul1, reshape_ref2, gamma_});
330 auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
331 MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
332 VectorRef add2_ref = VectorRef({is_add2, mul1_ref, beta_});
333 return add2_ref;
334 }
335 } // namespace opt
336 } // namespace mindspore
337