• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/norm_fusion.h"
#include <algorithm>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
21 #include "mindspore/core/ops/math_ops.h"
22 #include "mindspore/core/ops/lite_ops.h"
23 #include "mindspore/core/ops/array_ops.h"
24 #include "ops/fusion/layer_norm_fusion.h"
25 #include "ops/fusion/reduce_fusion.h"
26 #include "mindspore/core/ops/instance_norm.h"
27 #include "include/common/utils/utils.h"
28 #include "tools/optimizer/common/gllo_utils.h"
29 #include "securec/include/securec.h"
30 #include "nnacl/op_base.h"
31 #include "src/common/ops/anf_utils.h"
32 #include "ops/op_utils.h"
33 
34 namespace mindspore {
35 namespace opt {
36 namespace {
// Extracts the reduce axes bound to pattern input `n` into `axes`.
// `n` may be a Parameter carrying a constant int tensor, or a ValueNode holding
// the axes directly. Returns RET_OK on success, RET_NOT_SUPPORT for a
// Parameter without constant data, RET_ERROR otherwise.
STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
  MS_ASSERT(axes != nullptr);
  if (utils::isa<ParameterPtr>(n)) {
    auto axes_param = utils::cast<ParameterPtr>(n);
    if (!axes_param->has_default() || axes_param->default_param() == nullptr) {
      return lite::RET_NOT_SUPPORT;
    }
    auto axes_value = axes_param->default_param()->cast<tensor::TensorPtr>();
    if (axes_value == nullptr) {
      return lite::RET_ERROR;
    }
    if (axes_value->data_type() != kNumberTypeInt && axes_value->data_type() != kNumberTypeInt32) {
      MS_LOG(ERROR) << "reduce's axes should be integer, now is " << axes_value->data_type();
      return lite::RET_ERROR;
    }
    if (axes_value->data_c() == nullptr) {
      return lite::RET_ERROR;
    }
    // Only scalar (rank-0) or rank-1 axes tensors are supported.
    if (axes_value->shape().size() > 1) {
      return lite::RET_ERROR;
    }
    // A rank-0 (scalar) axes tensor still contributes exactly one axis entry.
    axes->resize(1);
    if (!axes_value->shape().empty()) {
      MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
      axes->resize(static_cast<size_t>(axes_value->shape()[0]));
    }
    // memcpy_s returns non-EOK on any size mismatch; in that case control falls
    // through to the RET_ERROR at the end of the function.
    if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) == EOK) {
      return lite::RET_OK;
    }
  }
  if (utils::isa<ValueNodePtr>(n)) {
    auto axes_value_node = utils::cast<ValueNodePtr>(n);
    *axes = CastToInt(axes_value_node->value());
    return lite::RET_OK;
  }
  return lite::RET_ERROR;
}
74 
IsReduceNode(const EquivPtr & equiv,const VarPtr & input_prim,const VarPtr & input_axes,std::vector<int> * axes)75 bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes, std::vector<int> *axes) {
76   MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
77   auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
78   MS_ASSERT(reduce_value != nullptr);
79   auto mean2_primitive = ops::GetOperator<ops::ReduceFusion>(reduce_value);
80   MS_CHECK_TRUE_RET(mean2_primitive != nullptr, false);
81   auto mean2_primitive_c = mean2_primitive->GetPrim();
82   if (mean2_primitive_c->GetAttr(ops::kMode) == nullptr || mean2_primitive->get_mode() != mindspore::Reduce_Mean) {
83     return false;
84   }
85   if (GetReduceAxes((*equiv)[input_axes], axes) != lite::RET_OK) {
86     return false;
87   }
88   return true;
89 }
90 }  // namespace
91 
Init() const92 bool NormFusion::Init() const {
93   input_ = std::make_shared<Var>();
94   MS_CHECK_TRUE_RET(input_ != nullptr, false);
95   mean1_ = std::make_shared<Var>();
96   MS_CHECK_TRUE_RET(mean1_ != nullptr, false);
97   mean1_axes_ = std::make_shared<Var>();
98   MS_CHECK_TRUE_RET(mean1_axes_ != nullptr, false);
99   mean2_ = std::make_shared<Var>();
100   MS_CHECK_TRUE_RET(mean2_ != nullptr, false);
101   mean2_axes_ = std::make_shared<Var>();
102   MS_CHECK_TRUE_RET(mean2_axes_ != nullptr, false);
103   gamma_ = std::make_shared<Var>();
104   MS_CHECK_TRUE_RET(gamma_ != nullptr, false);
105   beta_ = std::make_shared<Var>();
106   MS_CHECK_TRUE_RET(beta_ != nullptr, false);
107   epsilon_ = std::make_shared<Var>();
108   MS_CHECK_TRUE_RET(epsilon_ != nullptr, false);
109   return true;
110 }
111 
CreateNormNode(const FuncGraphPtr & func_graph,const EquivPtr & equiv,const schema::PrimitiveType type,float epsilon,int begin_norm_axis,int begin_params_axis) const112 CNodePtr NormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
113                                     const schema::PrimitiveType type, float epsilon, int begin_norm_axis,
114                                     int begin_params_axis) const {
115   MS_ASSERT(func_graph != nullptr);
116   MS_ASSERT(equiv != nullptr);
117   PrimitiveCPtr primitive_c = nullptr;
118   if (type == schema::PrimitiveType_LayerNormFusion) {
119     auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>();
120     MS_CHECK_TRUE_RET(layer_norm_primitive != nullptr, nullptr);
121     layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon, true);
122     auto layer_norm_primitive_c = layer_norm_primitive->GetPrim();
123     MS_CHECK_TRUE_RET(layer_norm_primitive_c != nullptr, nullptr);
124     primitive_c = layer_norm_primitive_c;
125   } else if (type == schema::PrimitiveType_InstanceNorm) {
126     auto instance_norm_primitive = std::make_shared<ops::InstanceNorm>();
127     MS_CHECK_TRUE_RET(instance_norm_primitive != nullptr, nullptr);
128     auto instance_norm_primitive_c = instance_norm_primitive->GetPrim();
129     MS_CHECK_TRUE_RET(instance_norm_primitive_c != nullptr, nullptr);
130     instance_norm_primitive->Init(epsilon);
131     primitive_c = instance_norm_primitive_c;
132   } else {
133     return nullptr;
134   }
135   auto value_node = NewValueNode(primitive_c);
136   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
137   std::vector<AnfNodePtr> new_node_inputs = {value_node};
138   auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
139   MS_ASSERT(input_node != nullptr);
140   new_node_inputs.push_back(input_node);
141   auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
142   MS_ASSERT(gamma_node != nullptr);
143   new_node_inputs.push_back(gamma_node);
144   auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
145   MS_ASSERT(beta_node != nullptr);
146   new_node_inputs.push_back(beta_node);
147   auto new_node = func_graph->NewCNode(new_node_inputs);
148   return new_node;
149 }
150 
// Decides whether the matched pattern should fuse to InstanceNorm or
// LayerNormFusion, and computes begin_norm_axis / begin_params_axis from the
// input's shape, the mean axes, and the gamma/beta shape.
// Returns false when the pattern cannot be fused.
bool NormFusion::GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode,
                                    const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
                                    schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(input_cnode != nullptr);
  MS_ASSERT(type != nullptr);
  MS_ASSERT(begin_norm_axis != nullptr);
  MS_ASSERT(begin_params_axis != nullptr);
  auto abstract = input_cnode->abstract();
  if (abstract == nullptr) {
    MS_LOG(DEBUG) << "abstract of input is nullptr";
    return false;
  }
  ShapeVector shape;
  if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
    MS_LOG(ERROR) << "fetch shape failed.";
    return false;
  }
  int shape_size = static_cast<int>(shape.size());
  // For dynamic shapes, fall back to the rank estimated by ShapeSizeInfer.
  if (lite::JudgeDynamicShape(shape)) {
    auto shape_size_map = ShapeSizeInfer(func_graph);
    if (shape_size_map.find(input_cnode->fullname_with_scope()) != shape_size_map.end()) {
      shape_size = shape_size_map[input_cnode->fullname_with_scope()];
    }
  }

  // Reduced axes must be consecutive; otherwise neither fusion applies.
  for (size_t i = 1; i < mean_axes.size(); ++i) {
    if (mean_axes[i] != mean_axes[i - 1] + 1) {
      MS_LOG(DEBUG) << "mean axes is not continuous";
      return false;
    }
  }
  // shape input has 4 dim && mean input has 2 dim and mean is in [1, 2 ,...]
  // i.e. NHWC input reduced over H,W with per-channel params => InstanceNorm.
  if (shape_size == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) {
    if (params_shape.size() == 1 && params_shape.back() == shape.back()) {
      *type = schema::PrimitiveType_InstanceNorm;
      return true;
    }
  }
  // LayerNorm requires the reduction to run through the last axis.
  // NOTE(review): a negative mean_axes.back() bypasses this check — confirm
  // callers only produce axes that are all non-negative or all negative.
  if (mean_axes.back() >= 0 && mean_axes.back() + 1 != shape_size) {
    MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
    return false;
  }

  // there is no need to check params_shape
  *begin_norm_axis = mean_axes.front();
  if (*begin_norm_axis >= 0) {
    *begin_params_axis = shape_size - static_cast<int>(params_shape.size());
    if (*begin_params_axis < 0) {
      MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse";
      return false;
    }
  } else {
    // Negative begin axis: express params axis relative to the end as well.
    *begin_params_axis = -static_cast<int>(params_shape.size());
  }

  *type = schema::PrimitiveType_LayerNormFusion;
  return true;
}
210 
CheckPattern(const FuncGraphPtr & func_graph,const EquivPtr & equiv,schema::PrimitiveType * type,float * epsilon,int * begin_norm_axis,int * begin_params_axis) const211 bool NormFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type,
212                               float *epsilon, int *begin_norm_axis, int *begin_params_axis) const {
213   MS_ASSERT(equiv != nullptr);
214   MS_ASSERT(epsilon != nullptr);
215   MS_ASSERT(type != nullptr);
216   MS_ASSERT(begin_norm_axis != nullptr);
217   MS_ASSERT(begin_params_axis != nullptr);
218   // beta
219   auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
220   MS_ASSERT(beta_node != nullptr);
221   if (!beta_node->isa<Parameter>()) {
222     return false;
223   }
224   auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
225   MS_CHECK_TRUE_RET(beta_param != nullptr, false);
226   auto beta_tensor = beta_param->cast<tensor::TensorPtr>();
227   MS_CHECK_TRUE_RET(beta_tensor != nullptr, false);
228   std::vector<int> beta_shape;
229   std::transform(beta_tensor->shape().begin(), beta_tensor->shape().end(), std::back_inserter(beta_shape),
230                  [](int64_t val) { return static_cast<int>(val); });
231   // gamma
232   auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
233   MS_ASSERT(gamma_node != nullptr);
234   if (!gamma_node->isa<Parameter>()) {
235     return false;
236   }
237   auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
238   MS_CHECK_TRUE_RET(gamma_param != nullptr, false);
239   auto gamma_tensor = gamma_param->cast<tensor::TensorPtr>();
240   MS_CHECK_TRUE_RET(gamma_tensor != nullptr, false);
241   std::vector<int> gamma_shape;
242   std::transform(gamma_tensor->shape().begin(), gamma_tensor->shape().end(), std::back_inserter(gamma_shape),
243                  [](int64_t val) { return static_cast<int>(val); });
244   // epsilon
245   auto epsilon_node = utils::cast<AnfNodePtr>((*equiv)[epsilon_]);
246   MS_ASSERT(epsilon_node != nullptr);
247   if (!epsilon_node->isa<Parameter>()) {
248     return false;
249   }
250   auto epsilon_param = epsilon_node->cast<ParameterPtr>()->default_param();
251   MS_CHECK_TRUE_RET(epsilon_param != nullptr, false);
252   auto epsilon_tensor = epsilon_param->cast<tensor::TensorPtr>();
253   MS_CHECK_TRUE_RET(epsilon_tensor != nullptr, false);
254   auto epsilon_shape = epsilon_tensor->shape();
255   // mean2
256   std::vector<int> mean2_axes;
257   if (!IsReduceNode(equiv, mean2_, mean2_axes_, &mean2_axes)) {
258     return false;
259   }
260   // mean1
261   std::vector<int> mean1_axes;
262   if (!IsReduceNode(equiv, mean1_, mean1_axes_, &mean1_axes)) {
263     return false;
264   }
265   auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
266   MS_ASSERT(input_node != nullptr);
267   if (!utils::isa<CNodePtr>(input_node)) {
268     return false;
269   }
270   auto input_cnode = input_node->cast<CNodePtr>();
271   if (mean1_axes != mean2_axes) {
272     return false;
273   }
274   if (gamma_shape != beta_shape) {
275     return false;
276   }
277   if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
278     MS_CHECK_TRUE_RET(epsilon_tensor->data_c() != nullptr, false);
279     auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->data_c());
280     *epsilon = epsilon_data[0];
281   } else {
282     return false;
283   }
284   return GetNormTypeAndAxis(func_graph, input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis);
285 }
286 
287 namespace {
CommonShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)288 int CommonShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
289   MS_ASSERT(in_shape_size.size() > 0);
290   return in_shape_size.at(0);
291 }
292 
ExpandDimsShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)293 int ExpandDimsShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
294   MS_ASSERT(in_shape_size.size() > 0);
295   return in_shape_size.at(0) + 1;
296 }
297 
StridedSliceShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)298 int StridedSliceShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
299   MS_ASSERT(in_shape_size.size() > 0);
300   MS_ASSERT(primitive.value.AsStridedSlice() != nullptr);
301   auto new_axis_mask = static_cast<size_t>(primitive.value.AsStridedSlice()->new_axis_mask);
302   auto add_dims = 0;
303   while (new_axis_mask != 0) {
304     new_axis_mask = (new_axis_mask - 1) & new_axis_mask;
305     add_dims++;
306   }
307   return in_shape_size.at(0) + add_dims;
308 }
309 
MatMulShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)310 int MatMulShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
311   MS_ASSERT(in_shape_size.size() > 1);
312   return in_shape_size[0];
313 }
314 
ReShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)315 int ReShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
316   MS_ASSERT(in_shape_size.size() > 1);
317   return in_shape_size[1];
318 }
319 
StackSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)320 int StackSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
321   MS_ASSERT(in_shape_size.size() > 1);
322   return std::accumulate(in_shape_size.begin(), in_shape_size.end(), 0);
323 }
324 
SqueezeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)325 int SqueezeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
326   MS_ASSERT(in_shape_size.size() > 0);
327   auto axis = primitive.value.AsSqueeze()->axis;
328   if (axis.empty()) {
329     return 0;
330   }
331   return in_shape_size.at(0) - axis.size();
332 }
333 
OneHotSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)334 int OneHotSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
335   MS_ASSERT(in_shape_size.size() > 0);
336   return in_shape_size.at(0) + 1;
337 }
338 
FillShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)339 int FillShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
340   MS_ASSERT(in_shape_size.size() > 1);
341   return in_shape_size.at(1);
342 }
343 
ShapeOpSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)344 int ShapeOpSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) { return 1; }
345 
BroadcastShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)346 int BroadcastShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
347   MS_ASSERT(in_shape_size.size() > 1);
348   int result = 0;
349   for (auto shape_size : in_shape_size) {
350     result = std::max(result, shape_size);
351   }
352   return result;
353 }
354 }  // namespace
355 
// Walks the graph in topological order and estimates each CNode's output rank
// ("shape size") using the per-op rules registered in
// shape_size_infer_registry_. Used when the real shape is dynamic.
// Returns a map from node full name to its estimated output rank.
std::map<string, int> NormFusion::ShapeSizeInfer(const FuncGraphPtr &func_graph) const {
  MS_ASSERT(func_graph != nullptr);
  std::map<string, int> node_shape_size;
  // For a few ops (Shape/StridedSlice/Stack) the actual 1-D shape *values* are
  // tracked too, so Reshape can read its target rank from its shape input.
  std::map<string, std::vector<int>> node_shape;
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_ASSERT(cnode != nullptr);
    auto prim_t = lite::GetPrimitiveT(cnode->input(0));
    if (prim_t == nullptr) {
      continue;
    }
    auto prim_type = prim_t->value.type;
    auto shape_size_infer_iter = shape_size_infer_registry_.find(prim_type);
    if (shape_size_infer_iter == shape_size_infer_registry_.end()) {
      // No rule registered for this op: skip it entirely.
      continue;
    }

    // specific op infer shape
    if (prim_type == schema::PrimitiveType_Shape) {
      tensor::TensorPtr tensor_info;
      auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, 1);
      if (ret == RET_OK) {
        // Shape's output holds the input's rank as its single element.
        node_shape[cnode->fullname_with_scope()] = {static_cast<int>(tensor_info->shape().size())};
      } else if (node_shape_size.find(cnode->input(1)->fullname_with_scope()) != node_shape_size.end()) {
        // Fall back to the rank previously inferred for the input node.
        node_shape[cnode->fullname_with_scope()] = {node_shape_size[cnode->input(1)->fullname_with_scope()]};
      }
    } else if (prim_type == schema::PrimitiveType_StridedSlice) {
      // Propagate the tracked shape values through the slice unchanged.
      node_shape[cnode->fullname_with_scope()] = node_shape[cnode->input(1)->fullname_with_scope()];
    } else if (prim_type == schema::PrimitiveType_Stack) {
      // Stacking prepends the number of stacked operands as a new dimension.
      auto shape = node_shape[cnode->input(1)->fullname_with_scope()];
      shape.insert(shape.begin(), cnode->size() - 1);
      node_shape[cnode->fullname_with_scope()] = shape;
    }

    // Get in node shape size
    std::vector<int> in_shape_sizes;
    for (size_t i = 1; i < cnode->size(); i++) {
      int in_shape_size = 0;
      if (utils::isa<CNodePtr>(cnode->input(i))) {
        in_shape_size = node_shape_size[cnode->input(i)->fullname_with_scope()];
        // second input of reshape is shape
        if (prim_type == schema::PrimitiveType_Reshape && i == THIRD_INPUT &&
            node_shape.find(cnode->input(i)->fullname_with_scope()) != node_shape.end()) {
          in_shape_size = node_shape[cnode->input(i)->fullname_with_scope()].at(0);
        }
      } else {
        // Constant/parameter input: read the rank from its tensor info.
        tensor::TensorPtr tensor_info;
        auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, i);
        if (ret == RET_OK) {
          in_shape_size = tensor_info->shape().size();
          // second input of reshape is shape
          if (prim_type == schema::PrimitiveType_Reshape && i == THIRD_INPUT) {
            in_shape_size = tensor_info->shape().at(0);
          }
        }
      }
      in_shape_sizes.emplace_back(in_shape_size);
    }
    // Cal shape size infer function
    auto shape_size_infer_func = shape_size_infer_iter->second;
    auto shape_size = shape_size_infer_func(in_shape_sizes, *prim_t);
    // Update node shape size map
    node_shape_size[cnode->fullname_with_scope()] = shape_size;
  }
  return node_shape_size;
}
426 
InitShapeSizeInferFuncMap()427 void NormFusion::InitShapeSizeInferFuncMap() {
428   if (!shape_size_infer_registry_.empty()) {
429     return;
430   }
431   shape_size_infer_registry_[schema::PrimitiveType_Activation] = CommonShapeSizeInfer;
432   shape_size_infer_registry_[schema::PrimitiveType_AddFusion] = BroadcastShapeSizeInfer;
433   shape_size_infer_registry_[schema::PrimitiveType_BiasAdd] = CommonShapeSizeInfer;
434   shape_size_infer_registry_[schema::PrimitiveType_Stack] = StackSizeInfer;
435   shape_size_infer_registry_[schema::PrimitiveType_Cast] = CommonShapeSizeInfer;
436   shape_size_infer_registry_[schema::PrimitiveType_Concat] = CommonShapeSizeInfer;
437   shape_size_infer_registry_[schema::PrimitiveType_ExpandDims] = ExpandDimsShapeSizeInfer;
438   shape_size_infer_registry_[schema::PrimitiveType_Fill] = FillShapeSizeInfer;
439   shape_size_infer_registry_[schema::PrimitiveType_LayerNormFusion] = CommonShapeSizeInfer;
440   shape_size_infer_registry_[schema::PrimitiveType_MatMulFusion] = MatMulShapeSizeInfer;
441   shape_size_infer_registry_[schema::PrimitiveType_MulFusion] = BroadcastShapeSizeInfer;
442   shape_size_infer_registry_[schema::PrimitiveType_OneHot] = OneHotSizeInfer;
443   shape_size_infer_registry_[schema::PrimitiveType_ReduceFusion] = CommonShapeSizeInfer;
444   shape_size_infer_registry_[schema::PrimitiveType_Reshape] = ReShapeSizeInfer;
445   shape_size_infer_registry_[schema::PrimitiveType_Shape] = ShapeOpSizeInfer;
446   shape_size_infer_registry_[schema::PrimitiveType_SliceFusion] = CommonShapeSizeInfer;
447   shape_size_infer_registry_[schema::PrimitiveType_Softmax] = CommonShapeSizeInfer;
448   shape_size_infer_registry_[schema::PrimitiveType_Squeeze] = SqueezeSizeInfer;
449   shape_size_infer_registry_[schema::PrimitiveType_StridedSlice] = StridedSliceShapeSizeInfer;
450   shape_size_infer_registry_[schema::PrimitiveType_Transpose] = CommonShapeSizeInfer;
451   shape_size_infer_registry_[schema::PrimitiveType_TileFusion] = CommonShapeSizeInfer;
452   shape_size_infer_registry_[schema::PrimitiveType_SquaredDifference] = CommonShapeSizeInfer;
453   shape_size_infer_registry_[schema::PrimitiveType_Rsqrt] = CommonShapeSizeInfer;
454   shape_size_infer_registry_[schema::PrimitiveType_SubFusion] = BroadcastShapeSizeInfer;
455   shape_size_infer_registry_[schema::PrimitiveType_PadFusion] = CommonShapeSizeInfer;
456   shape_size_infer_registry_[schema::PrimitiveType_PowFusion] = CommonShapeSizeInfer;
457 }
458 
CreateActivationNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node) const459 CNodePtr NormFusion::CreateActivationNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
460   auto act_primitive = std::make_shared<ops::Activation>();
461   MS_CHECK_TRUE_RET(act_primitive != nullptr, nullptr);
462   act_primitive->set_activation_type(add_act_type_);
463   auto act_primitive_c = act_primitive->GetPrim();
464   MS_CHECK_TRUE_RET(act_primitive_c != nullptr, nullptr);
465   auto value_node = NewValueNode(act_primitive_c);
466   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
467   std::vector<AnfNodePtr> new_node_inputs = {value_node};
468   new_node_inputs.push_back(node);
469   auto new_node = func_graph->NewCNode(new_node_inputs);
470   return new_node;
471 }
472 
// Pass entry point: `node` is the bottom Add of the matched pattern. Validates
// the sub-graph, creates the fused norm node, and — when the Add carried a
// fused activation — re-wires all users of the Add through a new Activation
// node. Returns the fused norm CNode, or nullptr when fusion is rejected.
const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                     const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    MS_LOG(ERROR) << "input param is nullptr, do norm fusion failed.";
    return nullptr;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto add2_cnode = node->cast<CNodePtr>();
  auto add2_primitive = ops::GetOperator<ops::AddFusion>(add2_cnode->input(0));
  MS_CHECK_TRUE_RET(add2_primitive != nullptr, nullptr);
  auto add2_primitive_c = add2_primitive->GetPrim();
  // Absorb the Add's fused activation so it can be re-emitted after the norm.
  // NOTE(review): add_act_type_ is not reset when the attribute is absent —
  // verify it cannot carry a stale value from a previous match.
  if (add2_primitive_c->GetAttr(ops::kActivationType) != nullptr) {
    add_act_type_ = add2_primitive->get_activation_type();
  }
  if (IsMarkedTrainOp(add2_cnode)) {
    return nullptr;
  }
  float epsilon = 0.0f;
  int begin_norm_axis = 0;
  int begin_params_axis = 0;
  schema::PrimitiveType type = schema::PrimitiveType_NONE;
  if (!CheckPattern(func_graph, equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) {
    return nullptr;
  }
  auto norm_cnode = CreateNormNode(func_graph, equiv, type, epsilon, begin_norm_axis, begin_params_axis);
  if (norm_cnode == nullptr) {
    MS_LOG(DEBUG) << "create norm cnode failed";
    return nullptr;
  }
  MS_CHECK_TRUE_RET(add2_cnode->abstract() != nullptr, nullptr);
  // The fused node takes over the Add's output abstract (type/shape).
  norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
  if (type == schema::PrimitiveType_LayerNormFusion) {
    norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
    MS_LOG(DEBUG) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
  } else if (type == schema::PrimitiveType_InstanceNorm) {
    norm_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope());
    MS_LOG(DEBUG) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
  }
  // insert relu node
  if (add_act_type_ != ActivationType::NO_ACTIVATION) {
    auto new_act_node = CreateActivationNode(func_graph, node);
    if (new_act_node == nullptr) {
      MS_LOG(ERROR) << "create act cnode failed.";
      return nullptr;
    }
    new_act_node->set_fullname_with_scope("relu_" + add2_cnode->fullname_with_scope());
    auto manager = func_graph->manager();
    if (manager == nullptr) {
      manager = Manage(func_graph, true);
      MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
    }
    // Redirect every user of the old Add to the new activation node; the
    // pattern engine replaces the Add itself with the returned norm_cnode.
    auto node_users = manager->node_users()[add2_cnode];
    for (auto &node_user : node_users) {
      auto next_cnode = node_user.first->cast<CNodePtr>();
      MS_CHECK_TRUE_RET(next_cnode != nullptr, nullptr);
      manager->SetEdge(next_cnode, node_user.second, new_act_node);
    }
    UpdateManager(func_graph);
  }
  return norm_cnode;
}
536 
DefinePattern() const537 const BaseRef TfNormFusion::DefinePattern() const {
538   if (!Init()) {
539     MS_LOG(ERROR) << "initial member failed.";
540     return {};
541   }
542   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
543   auto is_squared_diffference = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquaredDifference>);
544   MS_CHECK_TRUE_RET(is_squared_diffference != nullptr, {});
545   VectorRef squared_diffference1_ref = VectorRef({is_squared_diffference, input_, mean1_ref});
546   VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
547   auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
548   MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
549   VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
550   auto is_rsqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRsqrt>);
551   MS_CHECK_TRUE_RET(is_rsqrt != nullptr, {});
552   VectorRef rsqrt1_ref = VectorRef({is_rsqrt, add1_ref});
553   auto is_mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
554   MS_CHECK_TRUE_RET(is_mul2 != nullptr, {});
555   VectorRef mul2_ref = VectorRef({is_mul2, rsqrt1_ref, gamma_});
556   auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
557   MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
558   VectorRef mul1_ref = VectorRef({is_mul1, input_, mul2_ref});
559   auto is_mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
560   MS_CHECK_TRUE_RET(is_mul3 != nullptr, {});
561   VectorRef mul3_ref = VectorRef({is_mul3, mean1_ref, mul2_ref});
562   auto is_sub = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
563   MS_CHECK_TRUE_RET(is_sub != nullptr, {});
564   VectorRef sub1_ref = VectorRef({is_sub, beta_, mul3_ref});
565   auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
566   MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
567   VectorRef add2_ref = VectorRef({is_add2, mul1_ref, sub1_ref});
568   return add2_ref;
569 }
570 
DefinePattern() const571 const BaseRef OnnxLayerNormFusion::DefinePattern() const {
572   if (!Init()) {
573     MS_LOG(ERROR) << "initial member failed.";
574     return {};
575   }
576   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
577   auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
578   MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
579   VectorRef sub1_ref = VectorRef({is_sub1, input_, mean1_ref});
580   auto is_sub2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
581   MS_CHECK_TRUE_RET(is_sub2 != nullptr, {});
582   VectorRef sub2_ref = VectorRef({is_sub2, input_, mean1_ref});
583   auto is_pow = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>);
584   MS_CHECK_TRUE_RET(is_pow != nullptr, {});
585   auto is_var = std::make_shared<Var>();
586   MS_CHECK_TRUE_RET(is_var != nullptr, {});
587   VectorRef pow_ref = VectorRef({is_pow, sub2_ref, is_var});
588   VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
589   auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
590   MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
591   VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
592   auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
593   MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
594   VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
595   auto is_div = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
596   MS_CHECK_TRUE_RET(is_div != nullptr, {});
597   VectorRef div_ref = VectorRef({is_div, sub1_ref, sqrt_ref});
598   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
599   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
600   VectorRef mul_ref = VectorRef({is_mul, gamma_, div_ref});
601   auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
602   MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
603   VectorRef add2_ref = VectorRef({is_add2, mul_ref, beta_});
604   return add2_ref;
605 }
606 
607 // little different from OnnxLayerNormFusion on mul's inputs order
DefinePattern() const608 const BaseRef OnnxLayerNormFusion2::DefinePattern() const {
609   if (!Init()) {
610     MS_LOG(ERROR) << "initial member failed.";
611     return {};
612   }
613   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
614   auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
615   MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
616   VectorRef sub1_ref = VectorRef({is_sub1, input_, mean1_ref});
617   auto is_sub2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
618   MS_CHECK_TRUE_RET(is_sub2 != nullptr, {});
619   VectorRef sub2_ref = VectorRef({is_sub2, input_, mean1_ref});
620   auto is_pow = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>);
621   MS_CHECK_TRUE_RET(is_pow != nullptr, {});
622   auto is_var = std::make_shared<Var>();
623   MS_CHECK_TRUE_RET(is_var != nullptr, {});
624   VectorRef pow_ref = VectorRef({is_pow, sub2_ref, is_var});
625   VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
626   auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
627   MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
628   VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
629   auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
630   MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
631   VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
632   auto is_div = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
633   MS_CHECK_TRUE_RET(is_div != nullptr, {});
634   VectorRef div_ref = VectorRef({is_div, sub1_ref, sqrt_ref});
635   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
636   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
637   VectorRef mul_ref = VectorRef({is_mul, div_ref, gamma_});
638   auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
639   MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
640   VectorRef add2_ref = VectorRef({is_add2, mul_ref, beta_});
641   return add2_ref;
642 }
643 }  // namespace opt
644 }  // namespace mindspore
645