/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/groupnorm_fusion.h"
#include <algorithm>
#include <vector>
#include <memory>
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "ops/fusion/groupnorm_fusion.h"
#include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"
#include "src/common/ops/ops_utils.h"
#include "ops/op_utils.h"

namespace mindspore {
namespace opt {
namespace {
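// Reads reduce/reshape axes from a pattern capture: either a Parameter whose default value is an
// int/int32 tensor of rank at most 1, or a ValueNode. On success the values are written to *axes.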
STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
  MS_ASSERT(axes != nullptr);
  if (utils::isa<ParameterPtr>(n)) {
    auto axes_param = utils::cast<ParameterPtr>(n);
    if (!axes_param->has_default() || axes_param->default_param() == nullptr) {
      return lite::RET_NOT_SUPPORT;
    }
    auto axes_value = axes_param->default_param()->cast<tensor::TensorPtr>();
    if (axes_value == nullptr) {
      return lite::RET_ERROR;
    }
    if (axes_value->data_type() != kNumberTypeInt && axes_value->data_type() != kNumberTypeInt32) {
      MS_LOG(ERROR) << "reduce's axes should be integer, now is " << axes_value->data_type();
      return lite::RET_ERROR;
    }
    if (axes_value->data_c() == nullptr) {
      return lite::RET_ERROR;
    }
    if (axes_value->shape().size() > 1) {
      return lite::RET_ERROR;
    }
    axes->resize(1);
    if (!axes_value->shape().empty()) {
      MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
      axes->resize(static_cast<size_t>(axes_value->shape()[0]));
    }
    if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) == EOK) {
      return lite::RET_OK;
    }
  }
  if (utils::isa<ValueNodePtr>(n)) {
    auto axes_value_node = utils::cast<ValueNodePtr>(n);
    *axes = CastToInt(axes_value_node->value());
    return lite::RET_OK;
  }
  return lite::RET_ERROR;
}

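// Returns true when the node captured by input_prim is a ReduceFusion with mode Reduce_Sum,
// and extracts its reduce axes into *axes.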
bool IsReduceSumNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes,
                     std::vector<int> *axes) {
  MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
  auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
  MS_ASSERT(reduce_value != nullptr);
  auto mean2_primitive = ops::GetOperator<ops::ReduceFusion>(reduce_value);
  MS_CHECK_TRUE_RET(mean2_primitive != nullptr, false);
  auto mean2_primitive_c = mean2_primitive->GetPrim();
  if (mean2_primitive_c->GetAttr(ops::kMode) == nullptr || mean2_primitive->get_mode() != mindspore::Reduce_Sum) {
    return false;
  }
  if (GetAxis((*equiv)[input_axes], axes) != lite::RET_OK) {
    return false;
  }
  return true;
}

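// Returns true when the node captured by input_prim is a ReduceFusion with mode Reduce_Mean,
// and extracts its reduce axes into *axes.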
bool IsReduceMeanNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes,
                      std::vector<int> *axes) {
  MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
  auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
  MS_ASSERT(reduce_value != nullptr);
  auto mean2_primitive = ops::GetOperator<ops::ReduceFusion>(reduce_value);
  MS_CHECK_TRUE_RET(mean2_primitive != nullptr, false);
  auto mean2_primitive_c = mean2_primitive->GetPrim();
  if (mean2_primitive_c->GetAttr(ops::kMode) == nullptr || mean2_primitive->get_mode() != mindspore::Reduce_Mean) {
    return false;
  }
  if (GetAxis((*equiv)[input_axes], axes) != lite::RET_OK) {
    return false;
  }
  return true;
}
}  // namespace

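// Allocates the pattern variables (input, reduce primitives, reduce/reshape axes, gamma, beta,
// epsilon, and the variance divider) that DefinePattern() and CheckPattern() refer to.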
bool GroupNormFusion::Init() const {
  input_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(input_ != nullptr, false);
  mean1_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(mean1_ != nullptr, false);
  mean1_axis_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(mean1_axis_ != nullptr, false);
  sum1_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(sum1_ != nullptr, false);
  sum1_axis_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(sum1_axis_ != nullptr, false);
  reshape1_axis_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(reshape1_axis_ != nullptr, false);
  reshape2_axis_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(reshape2_axis_ != nullptr, false);
  gamma_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(gamma_ != nullptr, false);
  beta_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(beta_ != nullptr, false);
  epsilon_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(epsilon_ != nullptr, false);
  real_div_divider_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(real_div_divider_ != nullptr, false);

  return true;
}

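// Validates the matched subgraph: gamma and beta must be parameters with identical shapes, epsilon
// must be a scalar float tensor, and the sum and mean reductions must use the same axes. On success
// the group count and epsilon value are extracted for the fused node.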
bool GroupNormFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *num_groups,
                                   float *epsilon, bool *affine) const {
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(num_groups != nullptr);
  MS_ASSERT(epsilon != nullptr);
  MS_ASSERT(affine != nullptr);

  // beta
  auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
  MS_ASSERT(beta_node != nullptr);
  if (!beta_node->isa<Parameter>()) {
    return false;
  }
  auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
  MS_CHECK_TRUE_RET(beta_param != nullptr, false);
  auto beta_tensor = beta_param->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(beta_tensor != nullptr, false);
  std::vector<int> beta_shape;
  (void)std::transform(beta_tensor->shape().begin(), beta_tensor->shape().end(), std::back_inserter(beta_shape),
                       [](int64_t val) { return static_cast<int>(val); });
  // gamma
  auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
  MS_ASSERT(gamma_node != nullptr);
  if (!gamma_node->isa<Parameter>()) {
    return false;
  }
  auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
  MS_CHECK_TRUE_RET(gamma_param != nullptr, false);
  auto gamma_tensor = gamma_param->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(gamma_tensor != nullptr, false);
  std::vector<int> gamma_shape;
  (void)std::transform(gamma_tensor->shape().begin(), gamma_tensor->shape().end(), std::back_inserter(gamma_shape),
                       [](int64_t val) { return static_cast<int>(val); });
  // epsilon
  auto epsilon_node = utils::cast<AnfNodePtr>((*equiv)[epsilon_]);
  MS_ASSERT(epsilon_node != nullptr);
  if (!epsilon_node->isa<ValueNode>()) {
    return false;
  }
  auto epsilon_value_node = epsilon_node->cast<ValueNodePtr>();
  MS_CHECK_TRUE_RET(epsilon_value_node != nullptr, false);
  auto epsilon_value = epsilon_value_node->value();
  MS_CHECK_TRUE_RET(epsilon_value != nullptr, false);
  if (!epsilon_value->isa<tensor::Tensor>()) {
    std::cout << "CheckPattern:epsilon_value_node not tensor" << std::endl;
    return false;
  }
  auto epsilon_tensor = epsilon_value->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(epsilon_tensor != nullptr, false);
  TypeId tensor_type = epsilon_tensor->Dtype()->type_id();
  if (tensor_type != TypeId::kNumberTypeFloat32 && tensor_type != TypeId::kNumberTypeFloat) {
    std::cout << "CheckPattern:epsilon_value_node not float" << std::endl;
    return false;
  }
  auto epsilon_shape = epsilon_tensor->shape();
  // sum1
  std::vector<int> sum1_axes;
  if (!IsReduceSumNode(equiv, sum1_, sum1_axis_, &sum1_axes)) {
    return false;
  }
  // mean1
  std::vector<int> mean1_axes;
  if (!IsReduceMeanNode(equiv, mean1_, mean1_axis_, &mean1_axes)) {
    return false;
  }
  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
  MS_ASSERT(input_node != nullptr);
  if (!utils::isa<CNodePtr>(input_node)) {
    return false;
  }
  if (mean1_axes != sum1_axes) {
    return false;
  }
  if (gamma_shape != beta_shape) {
    return false;
  }
  if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
    MS_CHECK_TRUE_RET(epsilon_tensor->data_c() != nullptr, false);
    auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->data_c());
    *epsilon = epsilon_data[0];
  } else {
    return false;
  }
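  // The first Reshape's target shape must have exactly three elements; the group count is taken
  // from the second one (the shape is presumably (N, num_groups, -1)).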
  std::vector<int> reshape1_axes;
  if (GetAxis((*equiv)[reshape1_axis_], &reshape1_axes) != lite::RET_OK) {
    return false;
  }
  if (reshape1_axes.size() != C3NUM) {
    return false;
  }
  *num_groups = reshape1_axes.at(C1NUM);
  *affine = true;
  return true;
}

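// Builds the fused GroupNormFusion CNode with inputs (input, gamma, beta), carrying the extracted
// num_groups and epsilon as primitive attributes.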
CNodePtr GroupNormFusion::CreateGroupNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int num_groups,
                                              float epsilon) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  PrimitiveCPtr primitive_c = nullptr;

  auto layer_norm_primitive = std::make_shared<ops::GroupNormFusion>();
  MS_CHECK_TRUE_RET(layer_norm_primitive != nullptr, nullptr);
  layer_norm_primitive->Init(num_groups, epsilon, true);
  auto layer_norm_primitive_c = layer_norm_primitive->GetPrim();
  MS_CHECK_TRUE_RET(layer_norm_primitive_c != nullptr, nullptr);
  primitive_c = layer_norm_primitive_c;

  auto value_node = NewValueNode(primitive_c);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> new_node_inputs = {value_node};
  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
  MS_ASSERT(input_node != nullptr);
  new_node_inputs.push_back(input_node);
  auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
  MS_ASSERT(gamma_node != nullptr);
  new_node_inputs.push_back(gamma_node);
  auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
  MS_ASSERT(beta_node != nullptr);
  new_node_inputs.push_back(beta_node);
  auto new_node = func_graph->NewCNode(new_node_inputs);
  return new_node;
}

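// Pattern-match callback: `node` is the final Add (beta) of the matched subgraph. If CheckPattern
// succeeds, the decomposed subgraph is replaced by a single GroupNormFusion node.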
const AnfNodePtr GroupNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                          const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    MS_LOG(ERROR) << "input param is nullptr, do group norm fusion failed.";
    return nullptr;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto add2_cnode = node->cast<CNodePtr>();
  if (IsMarkedTrainOp(add2_cnode)) {
    return nullptr;
  }
  float epsilon = 0.0f;
  int num_groups = 0;
  bool affine = true;
  if (!CheckPattern(func_graph, equiv, &num_groups, &epsilon, &affine)) {
    return nullptr;
  }
  auto norm_cnode = CreateGroupNormNode(func_graph, equiv, num_groups, epsilon);
  if (norm_cnode == nullptr) {
    MS_LOG(DEBUG) << "create norm cnode failed";
    return nullptr;
  }
  MS_CHECK_TRUE_RET(add2_cnode->abstract() != nullptr, nullptr);
  norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
  norm_cnode->set_fullname_with_scope("group_norm_" + add2_cnode->fullname_with_scope());
  MS_LOG(DEBUG) << "group_norm_ node:" << norm_cnode->fullname_with_scope() << " fusion success";
  return norm_cnode;
}

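// Defines the decomposed GroupNorm subgraph to be matched:
//   x' = Reshape(x, reshape1_axis)
//   d  = x' - ReduceMean(x', mean1_axis)
//   v  = ReduceSum(Square(d), sum1_axis) / divider
//   y  = Reshape(d / Sqrt(v + epsilon), reshape2_axis) * gamma + beta
// The final Add (beta) is the anchor node handed to Process().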
const BaseRef GroupNormFusion::DefinePattern() const {
  if (!Init()) {
    MS_LOG(ERROR) << "initial member failed.";
    return {};
  }

  auto is_reshape1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  VectorRef reshape_ref1 = VectorRef({is_reshape1, input_, reshape1_axis_});
  VectorRef mean1_ref = VectorRef({mean1_, reshape_ref1, mean1_axis_});
  auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
  MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
  VectorRef sub1_ref = VectorRef({is_sub1, reshape_ref1, mean1_ref});

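  // Variance branch: sum of squared deviations, divided by the captured divider, plus epsilon, then sqrt.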
  auto is_sqare = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquare>);
  MS_CHECK_TRUE_RET(is_sqare != nullptr, {});
  VectorRef square_ref = VectorRef({is_sqare, sub1_ref});
  VectorRef sum1_ref = VectorRef({sum1_, square_ref, sum1_axis_});
  auto is_realdiv1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRealDiv>);
  MS_CHECK_TRUE_RET(is_realdiv1 != nullptr, {});
  VectorRef realdiv1_ref = VectorRef({is_realdiv1, sum1_ref, real_div_divider_});
  auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  VectorRef add1_ref = VectorRef({is_add1, realdiv1_ref, epsilon_});
  auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
  MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
  VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
  auto is_realdiv2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRealDiv>);
  MS_CHECK_TRUE_RET(is_realdiv2 != nullptr, {});
  VectorRef realdiv2_ref = VectorRef({is_realdiv2, sub1_ref, sqrt_ref});

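  // Affine stage: reshape to the output layout, then multiply by gamma and add beta.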
  auto is_reshape2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  VectorRef reshape_ref2 = VectorRef({is_reshape2, realdiv2_ref, reshape2_axis_});
  auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
  VectorRef mul1_ref = VectorRef({is_mul1, reshape_ref2, gamma_});
  auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  VectorRef add2_ref = VectorRef({is_add2, mul1_ref, beta_});
  return add2_ref;
}
}  // namespace opt
}  // namespace mindspore