• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/optimizer/fusion/norm_fusion.h"
17 #include <algorithm>
18 #include <memory>
19 #include "ops/fusion/layer_norm_fusion.h"
20 #include "ops/fusion/reduce_fusion.h"
21 #include "mindspore/core/ops/instance_norm.h"
22 #include "utils/utils.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "securec/include/securec.h"
25 #include "nnacl/op_base.h"
26 #include "src/ops/ops_utils.h"
27 
28 namespace mindspore {
29 namespace opt {
30 namespace {
GetReduceAxes(const BaseRef & n,std::vector<int> * axes)31 STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
32   MS_ASSERT(axes != nullptr);
33   if (utils::isa<ParameterPtr>(n)) {
34     auto axes_param = utils::cast<ParameterPtr>(n);
35     if (!axes_param->has_default() || axes_param->default_param() == nullptr) {
36       return lite::RET_NOT_SUPPORT;
37     }
38     auto axes_value = axes_param->default_param()->cast<tensor::TensorPtr>();
39     if (axes_value == nullptr) {
40       return lite::RET_ERROR;
41     }
42     if (axes_value->data_type() != kNumberTypeInt && axes_value->data_type() != kNumberTypeInt32) {
43       MS_LOG(ERROR) << "reduce's axes should be integer, now is " << axes_value->data_type();
44       return lite::RET_ERROR;
45     }
46     if (axes_value->data_c() == nullptr) {
47       return lite::RET_ERROR;
48     }
49     if (axes_value->shape().size() > 1) {
50       return lite::RET_ERROR;
51     }
52     axes->resize(1);
53     if (!axes_value->shape().empty()) {
54       MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
55       axes->resize(static_cast<size_t>(axes_value->shape()[0]));
56     }
57     if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) == EOK) {
58       return lite::RET_OK;
59     }
60   }
61   if (utils::isa<ValueNodePtr>(n)) {
62     auto axes_value_node = utils::cast<ValueNodePtr>(n);
63     *axes = CastToInt(axes_value_node->value());
64     return lite::RET_OK;
65   }
66   return lite::RET_ERROR;
67 }
68 
IsReduceNode(const EquivPtr & equiv,const VarPtr & input_prim,const VarPtr & input_axes,std::vector<int> * axes)69 bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes, std::vector<int> *axes) {
70   MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
71   auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
72   MS_ASSERT(reduce_value != nullptr);
73   auto mean2_primitive = GetValueNode<std::shared_ptr<ops::ReduceFusion>>(reduce_value);
74   if (mean2_primitive == nullptr || mean2_primitive->GetAttr(ops::kMode) == nullptr ||
75       mean2_primitive->get_mode() != mindspore::Reduce_Mean) {
76     return false;
77   }
78   if (GetReduceAxes((*equiv)[input_axes], axes) != lite::RET_OK) {
79     return false;
80   }
81   return true;
82 }
83 }  // namespace
84 
Init() const85 bool NormFusion::Init() const {
86   input_ = std::make_shared<Var>();
87   MS_CHECK_TRUE_RET(input_ != nullptr, false);
88   mean1_ = std::make_shared<Var>();
89   MS_CHECK_TRUE_RET(mean1_ != nullptr, false);
90   mean1_axes_ = std::make_shared<Var>();
91   MS_CHECK_TRUE_RET(mean1_axes_ != nullptr, false);
92   mean2_ = std::make_shared<Var>();
93   MS_CHECK_TRUE_RET(mean2_ != nullptr, false);
94   mean2_axes_ = std::make_shared<Var>();
95   MS_CHECK_TRUE_RET(mean2_axes_ != nullptr, false);
96   gamma_ = std::make_shared<Var>();
97   MS_CHECK_TRUE_RET(gamma_ != nullptr, false);
98   beta_ = std::make_shared<Var>();
99   MS_CHECK_TRUE_RET(beta_ != nullptr, false);
100   epsilon_ = std::make_shared<Var>();
101   MS_CHECK_TRUE_RET(epsilon_ != nullptr, false);
102   return true;
103 }
104 
CreateNormNode(const FuncGraphPtr & func_graph,const EquivPtr & equiv,const schema::PrimitiveType type,float epsilon,int begin_norm_axis,int begin_params_axis) const105 CNodePtr NormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
106                                     const schema::PrimitiveType type, float epsilon, int begin_norm_axis,
107                                     int begin_params_axis) const {
108   MS_ASSERT(func_graph != nullptr);
109   MS_ASSERT(equiv != nullptr);
110   PrimitiveCPtr primitive = nullptr;
111   if (type == schema::PrimitiveType_LayerNormFusion) {
112     auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>();
113     MS_CHECK_TRUE_RET(layer_norm_primitive != nullptr, nullptr);
114     layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon, true);
115     primitive = layer_norm_primitive;
116   } else if (type == schema::PrimitiveType_InstanceNorm) {
117     auto instance_norm_primitive = std::make_shared<ops::InstanceNorm>();
118     MS_CHECK_TRUE_RET(instance_norm_primitive != nullptr, nullptr);
119     instance_norm_primitive->Init(epsilon);
120     primitive = instance_norm_primitive;
121   } else {
122     return nullptr;
123   }
124   auto value_node = NewValueNode(primitive);
125   MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
126   std::vector<AnfNodePtr> new_node_inputs = {value_node};
127   auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
128   MS_ASSERT(input_node != nullptr);
129   new_node_inputs.push_back(input_node);
130   auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
131   MS_ASSERT(gamma_node != nullptr);
132   new_node_inputs.push_back(gamma_node);
133   auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
134   MS_ASSERT(beta_node != nullptr);
135   new_node_inputs.push_back(beta_node);
136   auto new_node = func_graph->NewCNode(new_node_inputs);
137   return new_node;
138 }
139 
GetNormTypeAndAxis(const FuncGraphPtr & func_graph,const CNodePtr & input_cnode,const std::vector<int> & mean_axes,const std::vector<int> & params_shape,schema::PrimitiveType * type,int * begin_norm_axis,int * begin_params_axis) const140 bool NormFusion::GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode,
141                                     const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
142                                     schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const {
143   MS_ASSERT(func_graph != nullptr);
144   MS_ASSERT(input_cnode != nullptr);
145   MS_ASSERT(type != nullptr);
146   MS_ASSERT(begin_norm_axis != nullptr);
147   MS_ASSERT(begin_params_axis != nullptr);
148   auto abstract = input_cnode->abstract();
149   if (abstract == nullptr) {
150     MS_LOG(DEBUG) << "abstract of input is nullptr";
151     return false;
152   }
153   ShapeVector shape;
154   if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
155     MS_LOG(ERROR) << "fetch shape failed.";
156     return false;
157   }
158   int shape_size = static_cast<int>(shape.size());
159   if (shape.empty()) {
160     auto shape_size_map = ShapeSizeInfer(func_graph);
161     if (shape_size_map.find(input_cnode->fullname_with_scope()) != shape_size_map.end()) {
162       shape_size = shape_size_map[input_cnode->fullname_with_scope()];
163     }
164   }
165 
166   for (size_t i = 1; i < mean_axes.size(); ++i) {
167     if (mean_axes[i] != mean_axes[i - 1] + 1) {
168       MS_LOG(DEBUG) << "mean axes is not continuous";
169       return false;
170     }
171   }
172   // shape input has 4 dim && mean input has 2 dim and mean is in [1, 2 ,...]
173   if (shape_size == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) {
174     if (params_shape.size() == 1 && params_shape.back() == shape.back()) {
175       *type = schema::PrimitiveType_InstanceNorm;
176       return true;
177     }
178   }
179   if (mean_axes.back() >= 0 && mean_axes.back() + 1 != shape_size) {
180     MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
181     return false;
182   }
183 
184   // there is no need to check params_shape
185   *begin_norm_axis = mean_axes.front();
186   if (*begin_norm_axis >= 0) {
187     *begin_params_axis = shape_size - static_cast<int>(params_shape.size());
188     if (*begin_params_axis < 0) {
189       MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse";
190       return false;
191     }
192   } else {
193     *begin_params_axis = -static_cast<int>(params_shape.size());
194   }
195 
196   *type = schema::PrimitiveType_LayerNormFusion;
197   return true;
198 }
199 
CheckPattern(const FuncGraphPtr & func_graph,const EquivPtr & equiv,schema::PrimitiveType * type,float * epsilon,int * begin_norm_axis,int * begin_params_axis) const200 bool NormFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type,
201                               float *epsilon, int *begin_norm_axis, int *begin_params_axis) const {
202   MS_ASSERT(equiv != nullptr);
203   MS_ASSERT(epsilon != nullptr);
204   MS_ASSERT(type != nullptr);
205   MS_ASSERT(begin_norm_axis != nullptr);
206   MS_ASSERT(begin_params_axis != nullptr);
207   // beta
208   auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
209   MS_ASSERT(beta_node != nullptr);
210   if (!beta_node->isa<Parameter>()) {
211     return false;
212   }
213   auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
214   MS_CHECK_TRUE_RET(beta_param != nullptr, false);
215   auto beta_tensor = beta_param->cast<tensor::TensorPtr>();
216   MS_CHECK_TRUE_RET(beta_tensor != nullptr, false);
217   std::vector<int> beta_shape;
218   std::transform(beta_tensor->shape().begin(), beta_tensor->shape().end(), std::back_inserter(beta_shape),
219                  [](int64_t val) { return static_cast<int>(val); });
220   // gamma
221   auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
222   MS_ASSERT(gamma_node != nullptr);
223   if (!gamma_node->isa<Parameter>()) {
224     return false;
225   }
226   auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
227   MS_CHECK_TRUE_RET(gamma_param != nullptr, false);
228   auto gamma_tensor = gamma_param->cast<tensor::TensorPtr>();
229   MS_CHECK_TRUE_RET(gamma_tensor != nullptr, false);
230   std::vector<int> gamma_shape;
231   std::transform(gamma_tensor->shape().begin(), gamma_tensor->shape().end(), std::back_inserter(gamma_shape),
232                  [](int64_t val) { return static_cast<int>(val); });
233   // epsilon
234   auto epsilon_node = utils::cast<AnfNodePtr>((*equiv)[epsilon_]);
235   MS_ASSERT(epsilon_node != nullptr);
236   if (!epsilon_node->isa<Parameter>()) {
237     return false;
238   }
239   auto epsilon_param = epsilon_node->cast<ParameterPtr>()->default_param();
240   MS_CHECK_TRUE_RET(epsilon_param != nullptr, false);
241   auto epsilon_tensor = epsilon_param->cast<tensor::TensorPtr>();
242   MS_CHECK_TRUE_RET(epsilon_tensor != nullptr, false);
243   auto epsilon_shape = epsilon_tensor->shape();
244   // mean2
245   std::vector<int> mean2_axes;
246   if (!IsReduceNode(equiv, mean2_, mean2_axes_, &mean2_axes)) {
247     return false;
248   }
249   // mean1
250   std::vector<int> mean1_axes;
251   if (!IsReduceNode(equiv, mean1_, mean1_axes_, &mean1_axes)) {
252     return false;
253   }
254   auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
255   MS_ASSERT(input_node != nullptr);
256   if (!utils::isa<CNodePtr>(input_node)) {
257     return false;
258   }
259   auto input_cnode = input_node->cast<CNodePtr>();
260   if (mean1_axes != mean2_axes) {
261     return false;
262   }
263   if (gamma_shape != beta_shape) {
264     return false;
265   }
266   if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
267     MS_CHECK_TRUE_RET(epsilon_tensor->data_c() != nullptr, false);
268     auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->data_c());
269     *epsilon = epsilon_data[0];
270   } else {
271     return false;
272   }
273   return GetNormTypeAndAxis(func_graph, input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis);
274 }
275 
276 namespace {
CommonShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)277 int CommonShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
278   MS_ASSERT(in_shape_size.size() > 0);
279   return in_shape_size.at(0);
280 }
281 
ExpandDimsShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)282 int ExpandDimsShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
283   MS_ASSERT(in_shape_size.size() > 0);
284   return in_shape_size.at(0) + 1;
285 }
286 
StridedSliceShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)287 int StridedSliceShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
288   MS_ASSERT(in_shape_size.size() > 0);
289   MS_ASSERT(primitive.value.AsStridedSlice() != nullptr);
290   auto new_axis_mask = static_cast<size_t>(primitive.value.AsStridedSlice()->new_axis_mask);
291   auto add_dims = 0;
292   while (new_axis_mask != 0) {
293     new_axis_mask = (new_axis_mask - 1) & new_axis_mask;
294     add_dims++;
295   }
296   return in_shape_size.at(0) + add_dims;
297 }
298 
MatMulShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)299 int MatMulShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
300   MS_ASSERT(in_shape_size.size() > 1);
301   return in_shape_size[0];
302 }
303 
ReShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)304 int ReShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
305   MS_ASSERT(in_shape_size.size() > 1);
306   return in_shape_size[1];
307 }
308 
StackSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)309 int StackSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
310   MS_ASSERT(in_shape_size.size() > 1);
311   return std::accumulate(in_shape_size.begin(), in_shape_size.end(), 0);
312 }
313 
SqueezeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)314 int SqueezeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
315   MS_ASSERT(in_shape_size.size() > 0);
316   auto axis = primitive.value.AsSqueeze()->axis;
317   if (axis.empty()) {
318     return 0;
319   }
320   return in_shape_size.at(0) - axis.size();
321 }
322 
OneHotSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)323 int OneHotSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
324   MS_ASSERT(in_shape_size.size() > 0);
325   return in_shape_size.at(0) + 1;
326 }
327 
FillShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)328 int FillShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
329   MS_ASSERT(in_shape_size.size() > 1);
330   return in_shape_size.at(1);
331 }
332 
ShapeOpSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)333 int ShapeOpSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) { return 1; }
334 
BroadcastShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)335 int BroadcastShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
336   MS_ASSERT(in_shape_size.size() > 1);
337   int result = 0;
338   for (auto shape_size : in_shape_size) {
339     result = std::max(result, shape_size);
340   }
341   return result;
342 }
343 }  // namespace
344 
ShapeSizeInfer(const FuncGraphPtr & func_graph) const345 std::map<string, int> NormFusion::ShapeSizeInfer(const FuncGraphPtr &func_graph) const {
346   MS_ASSERT(func_graph != nullptr);
347   std::map<string, int> node_shape_size;
348   std::map<string, std::vector<int>> node_shape;
349   auto node_list = TopoSort(func_graph->get_return());
350   for (auto &node : node_list) {
351     if (!utils::isa<CNodePtr>(node)) {
352       continue;
353     }
354     auto cnode = node->cast<CNodePtr>();
355     auto prim_t = lite::GetPrimitiveT(cnode->input(0));
356     if (prim_t == nullptr) {
357       continue;
358     }
359     auto prim_type = prim_t->value.type;
360     auto shape_size_infer_iter = shape_size_infer_registry_.find(prim_type);
361     if (shape_size_infer_iter == shape_size_infer_registry_.end()) {
362       continue;
363     }
364 
365     // specific op infer shape
366     if (prim_type == schema::PrimitiveType_Shape) {
367       tensor::TensorPtr tensor_info;
368       auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, 1);
369       if (ret == RET_OK) {
370         node_shape[cnode->fullname_with_scope()] = {static_cast<int>(tensor_info->shape().size())};
371       } else if (node_shape_size.find(cnode->input(1)->fullname_with_scope()) != node_shape_size.end()) {
372         node_shape[cnode->fullname_with_scope()] = {node_shape_size[cnode->input(1)->fullname_with_scope()]};
373       }
374     } else if (prim_type == schema::PrimitiveType_StridedSlice) {
375       node_shape[cnode->fullname_with_scope()] = node_shape[cnode->input(1)->fullname_with_scope()];
376     } else if (prim_type == schema::PrimitiveType_Stack) {
377       auto shape = node_shape[cnode->input(1)->fullname_with_scope()];
378       shape.insert(shape.begin(), cnode->inputs().size() - 1);
379       node_shape[cnode->fullname_with_scope()] = shape;
380     }
381 
382     // Get in node shape size
383     std::vector<int> in_shape_sizes;
384     for (size_t i = 1; i < cnode->inputs().size(); i++) {
385       int in_shape_size = 0;
386       if (utils::isa<CNodePtr>(cnode->input(i))) {
387         in_shape_size = node_shape_size[cnode->input(i)->fullname_with_scope()];
388         // second input of reshape is shape
389         if (prim_type == schema::PrimitiveType_Reshape && i == THIRD_INPUT &&
390             node_shape.find(cnode->input(i)->fullname_with_scope()) != node_shape.end()) {
391           in_shape_size = node_shape[cnode->input(i)->fullname_with_scope()].at(0);
392         }
393       } else {
394         tensor::TensorPtr tensor_info;
395         auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, i);
396         if (ret == RET_OK) {
397           in_shape_size = tensor_info->shape().size();
398           // second input of reshape is shape
399           if (prim_type == schema::PrimitiveType_Reshape && i == THIRD_INPUT) {
400             in_shape_size = tensor_info->shape().at(0);
401           }
402         }
403       }
404       in_shape_sizes.emplace_back(in_shape_size);
405     }
406     // Cal shape size infer function
407     auto shape_size_infer_func = shape_size_infer_iter->second;
408     auto shape_size = shape_size_infer_func(in_shape_sizes, *prim_t);
409     // Update node shape size map
410     node_shape_size[cnode->fullname_with_scope()] = shape_size;
411   }
412   return node_shape_size;
413 }
414 
InitShapeSizeInferFuncMap()415 void NormFusion::InitShapeSizeInferFuncMap() {
416   if (!shape_size_infer_registry_.empty()) {
417     return;
418   }
419   shape_size_infer_registry_[schema::PrimitiveType_Activation] = CommonShapeSizeInfer;
420   shape_size_infer_registry_[schema::PrimitiveType_AddFusion] = BroadcastShapeSizeInfer;
421   shape_size_infer_registry_[schema::PrimitiveType_BiasAdd] = CommonShapeSizeInfer;
422   shape_size_infer_registry_[schema::PrimitiveType_Stack] = StackSizeInfer;
423   shape_size_infer_registry_[schema::PrimitiveType_Cast] = CommonShapeSizeInfer;
424   shape_size_infer_registry_[schema::PrimitiveType_Concat] = CommonShapeSizeInfer;
425   shape_size_infer_registry_[schema::PrimitiveType_ExpandDims] = ExpandDimsShapeSizeInfer;
426   shape_size_infer_registry_[schema::PrimitiveType_Fill] = FillShapeSizeInfer;
427   shape_size_infer_registry_[schema::PrimitiveType_LayerNormFusion] = CommonShapeSizeInfer;
428   shape_size_infer_registry_[schema::PrimitiveType_MatMul] = MatMulShapeSizeInfer;
429   shape_size_infer_registry_[schema::PrimitiveType_MulFusion] = BroadcastShapeSizeInfer;
430   shape_size_infer_registry_[schema::PrimitiveType_OneHot] = OneHotSizeInfer;
431   shape_size_infer_registry_[schema::PrimitiveType_ReduceFusion] = CommonShapeSizeInfer;
432   shape_size_infer_registry_[schema::PrimitiveType_Reshape] = ReShapeSizeInfer;
433   shape_size_infer_registry_[schema::PrimitiveType_Shape] = ShapeOpSizeInfer;
434   shape_size_infer_registry_[schema::PrimitiveType_SliceFusion] = CommonShapeSizeInfer;
435   shape_size_infer_registry_[schema::PrimitiveType_Softmax] = CommonShapeSizeInfer;
436   shape_size_infer_registry_[schema::PrimitiveType_Squeeze] = SqueezeSizeInfer;
437   shape_size_infer_registry_[schema::PrimitiveType_StridedSlice] = StridedSliceShapeSizeInfer;
438   shape_size_infer_registry_[schema::PrimitiveType_Transpose] = CommonShapeSizeInfer;
439   shape_size_infer_registry_[schema::PrimitiveType_TileFusion] = CommonShapeSizeInfer;
440   shape_size_infer_registry_[schema::PrimitiveType_SquaredDifference] = CommonShapeSizeInfer;
441   shape_size_infer_registry_[schema::PrimitiveType_Rsqrt] = CommonShapeSizeInfer;
442   shape_size_infer_registry_[schema::PrimitiveType_SubFusion] = BroadcastShapeSizeInfer;
443   shape_size_infer_registry_[schema::PrimitiveType_PadFusion] = CommonShapeSizeInfer;
444   shape_size_infer_registry_[schema::PrimitiveType_PowFusion] = CommonShapeSizeInfer;
445 }
446 
// Pattern-match callback: replaces the matched add-at-the-end subgraph with a
// single fused LayerNormFusion / InstanceNorm node, or returns nullptr when
// the match fails validation.
const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                     const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    MS_LOG(ERROR) << "input param is nullptr, do norm fusion failed.";
    return nullptr;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto add2_cnode = node->cast<CNodePtr>();
  if (IsMarkedTrainOp(add2_cnode)) {
    return nullptr;  // never fuse nodes tagged for training
  }
  // Validate the matched subgraph and pull out the fusion parameters.
  schema::PrimitiveType norm_type = schema::PrimitiveType_NONE;
  float eps = 0.0f;
  int norm_axis = 0;
  int params_axis = 0;
  if (!CheckPattern(func_graph, equiv, &norm_type, &eps, &norm_axis, &params_axis)) {
    return nullptr;
  }
  auto fused_cnode = CreateNormNode(func_graph, equiv, norm_type, eps, norm_axis, params_axis);
  if (fused_cnode == nullptr) {
    MS_LOG(DEBUG) << "create norm cnode failed";
    return nullptr;
  }
  // The fused node takes over the output abstract of the final add.
  MS_CHECK_TRUE_RET(add2_cnode->abstract() != nullptr, nullptr);
  fused_cnode->set_abstract(add2_cnode->abstract()->Clone());
  if (norm_type == schema::PrimitiveType_LayerNormFusion) {
    fused_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
    MS_LOG(DEBUG) << "layer_norm node:" << fused_cnode->fullname_with_scope() << " fusion success";
  } else if (norm_type == schema::PrimitiveType_InstanceNorm) {
    fused_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope());
    MS_LOG(DEBUG) << "instance_norm node:" << fused_cnode->fullname_with_scope() << " fusion success";
  }
  return fused_cnode;
}
483 
DefinePattern() const484 const BaseRef TfNormFusion::DefinePattern() const {
485   if (!Init()) {
486     MS_LOG(ERROR) << "initial member failed.";
487     return {};
488   }
489   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
490   auto is_squared_diffference = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquaredDifference>);
491   MS_CHECK_TRUE_RET(is_squared_diffference != nullptr, {});
492   VectorRef squared_diffference1_ref = VectorRef({is_squared_diffference, input_, mean1_ref});
493   VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
494   auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
495   MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
496   VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
497   auto is_rsqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRsqrt>);
498   MS_CHECK_TRUE_RET(is_rsqrt != nullptr, {});
499   VectorRef rsqrt1_ref = VectorRef({is_rsqrt, add1_ref});
500   auto is_mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
501   MS_CHECK_TRUE_RET(is_mul2 != nullptr, {});
502   VectorRef mul2_ref = VectorRef({is_mul2, rsqrt1_ref, gamma_});
503   auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
504   MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
505   VectorRef mul1_ref = VectorRef({is_mul1, input_, mul2_ref});
506   auto is_mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
507   MS_CHECK_TRUE_RET(is_mul3 != nullptr, {});
508   VectorRef mul3_ref = VectorRef({is_mul3, mean1_ref, mul2_ref});
509   auto is_sub = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
510   MS_CHECK_TRUE_RET(is_sub != nullptr, {});
511   VectorRef sub1_ref = VectorRef({is_sub, beta_, mul3_ref});
512   auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
513   MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
514   VectorRef add2_ref = VectorRef({is_add2, mul1_ref, sub1_ref});
515   return add2_ref;
516 }
517 
DefinePattern() const518 const BaseRef OnnxLayerNormFusion::DefinePattern() const {
519   if (!Init()) {
520     MS_LOG(ERROR) << "initial member failed.";
521     return {};
522   }
523   VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
524   auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
525   MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
526   VectorRef sub1_ref = VectorRef({is_sub1, input_, mean1_ref});
527   auto is_sub2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
528   MS_CHECK_TRUE_RET(is_sub2 != nullptr, {});
529   VectorRef sub2_ref = VectorRef({is_sub2, input_, mean1_ref});
530   auto is_pow = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>);
531   MS_CHECK_TRUE_RET(is_pow != nullptr, {});
532   auto is_var = std::make_shared<Var>();
533   MS_CHECK_TRUE_RET(is_var != nullptr, {});
534   VectorRef pow_ref = VectorRef({is_pow, sub2_ref, is_var});
535   VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
536   auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
537   MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
538   VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
539   auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
540   MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
541   VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
542   auto is_div = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
543   MS_CHECK_TRUE_RET(is_div != nullptr, {});
544   VectorRef div_ref = VectorRef({is_div, sub1_ref, sqrt_ref});
545   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
546   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
547   VectorRef mul_ref = VectorRef({is_mul, gamma_, div_ref});
548   auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
549   MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
550   VectorRef add2_ref = VectorRef({is_add2, mul_ref, beta_});
551   return add2_ref;
552 }
553 }  // namespace opt
554 }  // namespace mindspore
555