/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/optimizer/fusion/norm_fusion.h"
#include <algorithm>
#include <memory>
#include <numeric>
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/reduce_fusion.h"
#include "mindspore/core/ops/instance_norm.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"
#include "src/ops/ops_utils.h"

namespace mindspore {
namespace opt {
namespace {
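// Extract the reduce axes from the matched axes input, which may be either a const Parameter
// (a 0-D or 1-D int tensor) or a ValueNode; returns RET_NOT_SUPPORT/RET_ERROR otherwise.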
STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
  MS_ASSERT(axes != nullptr);
  if (utils::isa<ParameterPtr>(n)) {
    auto axes_param = utils::cast<ParameterPtr>(n);
    if (!axes_param->has_default() || axes_param->default_param() == nullptr) {
      return lite::RET_NOT_SUPPORT;
    }
    auto axes_value = axes_param->default_param()->cast<tensor::TensorPtr>();
    if (axes_value == nullptr) {
      return lite::RET_ERROR;
    }
    if (axes_value->data_type() != kNumberTypeInt && axes_value->data_type() != kNumberTypeInt32) {
      MS_LOG(ERROR) << "reduce's axes should be integer, now is " << axes_value->data_type();
      return lite::RET_ERROR;
    }
    if (axes_value->data_c() == nullptr) {
      return lite::RET_ERROR;
    }
    if (axes_value->shape().size() > 1) {
      return lite::RET_ERROR;
    }
    axes->resize(1);
    if (!axes_value->shape().empty()) {
      MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
      axes->resize(static_cast<size_t>(axes_value->shape()[0]));
    }
    if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) == EOK) {
      return lite::RET_OK;
    }
  }
  if (utils::isa<ValueNodePtr>(n)) {
    auto axes_value_node = utils::cast<ValueNodePtr>(n);
    *axes = CastToInt(axes_value_node->value());
    return lite::RET_OK;
  }
  return lite::RET_ERROR;
}

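// Check that the node bound to input_prim is a ReduceFusion with mode Reduce_Mean,
// and extract its reduce axes into *axes.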
bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes, std::vector<int> *axes) {
  MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
  auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
  MS_ASSERT(reduce_value != nullptr);
  auto mean2_primitive = GetValueNode<std::shared_ptr<ops::ReduceFusion>>(reduce_value);
  if (mean2_primitive == nullptr || mean2_primitive->GetAttr(ops::kMode) == nullptr ||
      mean2_primitive->get_mode() != mindspore::Reduce_Mean) {
    return false;
  }
  if (GetReduceAxes((*equiv)[input_axes], axes) != lite::RET_OK) {
    return false;
  }
  return true;
}
}  // namespace

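// Create the pattern variables shared by all norm patterns: the input, the two mean reduces
// with their axes, gamma, beta, and epsilon.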
bool NormFusion::Init() const {
  input_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(input_ != nullptr, false);
  mean1_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(mean1_ != nullptr, false);
  mean1_axes_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(mean1_axes_ != nullptr, false);
  mean2_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(mean2_ != nullptr, false);
  mean2_axes_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(mean2_axes_ != nullptr, false);
  gamma_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(gamma_ != nullptr, false);
  beta_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(beta_ != nullptr, false);
  epsilon_ = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(epsilon_ != nullptr, false);
  return true;
}

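// Build the fused node, either LayerNormFusion(input, gamma, beta) carrying epsilon and the
// begin_norm_axis/begin_params_axis attributes, or InstanceNorm(input, gamma, beta) carrying epsilon.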
CNodePtr NormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                    const schema::PrimitiveType type, float epsilon, int begin_norm_axis,
                                    int begin_params_axis) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  PrimitiveCPtr primitive = nullptr;
  if (type == schema::PrimitiveType_LayerNormFusion) {
    auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>();
    MS_CHECK_TRUE_RET(layer_norm_primitive != nullptr, nullptr);
    layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon, true);
    primitive = layer_norm_primitive;
  } else if (type == schema::PrimitiveType_InstanceNorm) {
    auto instance_norm_primitive = std::make_shared<ops::InstanceNorm>();
    MS_CHECK_TRUE_RET(instance_norm_primitive != nullptr, nullptr);
    instance_norm_primitive->Init(epsilon);
    primitive = instance_norm_primitive;
  } else {
    return nullptr;
  }
  auto value_node = NewValueNode(primitive);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> new_node_inputs = {value_node};
  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
  MS_ASSERT(input_node != nullptr);
  new_node_inputs.push_back(input_node);
  auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
  MS_ASSERT(gamma_node != nullptr);
  new_node_inputs.push_back(gamma_node);
  auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
  MS_ASSERT(beta_node != nullptr);
  new_node_inputs.push_back(beta_node);
  auto new_node = func_graph->NewCNode(new_node_inputs);
  return new_node;
}

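// Decide between InstanceNorm and LayerNormFusion from the input rank and the reduce axes,
// and compute begin_norm_axis / begin_params_axis for the layer-norm case. When the abstract
// carries no shape, fall back to the rank recorded by ShapeSizeInfer.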
bool NormFusion::GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode,
                                    const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
                                    schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(input_cnode != nullptr);
  MS_ASSERT(type != nullptr);
  MS_ASSERT(begin_norm_axis != nullptr);
  MS_ASSERT(begin_params_axis != nullptr);
  auto abstract = input_cnode->abstract();
  if (abstract == nullptr) {
    MS_LOG(DEBUG) << "abstract of input is nullptr";
    return false;
  }
  ShapeVector shape;
  if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
    MS_LOG(ERROR) << "fetch shape failed.";
    return false;
  }
  int shape_size = static_cast<int>(shape.size());
  if (shape.empty()) {
    auto shape_size_map = ShapeSizeInfer(func_graph);
    if (shape_size_map.find(input_cnode->fullname_with_scope()) != shape_size_map.end()) {
      shape_size = shape_size_map[input_cnode->fullname_with_scope()];
    }
  }

  for (size_t i = 1; i < mean_axes.size(); ++i) {
    if (mean_axes[i] != mean_axes[i - 1] + 1) {
      MS_LOG(DEBUG) << "mean axes is not continuous";
      return false;
    }
  }
  // a 4-D input whose mean reduces exactly axes [1, 2], with 1-D gamma/beta matching the last
  // axis, is an instance norm
  if (shape_size == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) {
    if (params_shape.size() == 1 && params_shape.back() == shape.back()) {
      *type = schema::PrimitiveType_InstanceNorm;
      return true;
    }
  }
  if (mean_axes.back() >= 0 && mean_axes.back() + 1 != shape_size) {
    MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
    return false;
  }

  // there is no need to check params_shape
  *begin_norm_axis = mean_axes.front();
  if (*begin_norm_axis >= 0) {
    *begin_params_axis = shape_size - static_cast<int>(params_shape.size());
    if (*begin_params_axis < 0) {
      MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse";
      return false;
    }
  } else {
    *begin_params_axis = -static_cast<int>(params_shape.size());
  }

  *type = schema::PrimitiveType_LayerNormFusion;
  return true;
}

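// Validate the matched sub-graph: gamma, beta, and epsilon must be const Parameters, both
// reduces must be Mean over identical axes, gamma and beta must share a shape, and epsilon
// must be a scalar; then classify the norm type via GetNormTypeAndAxis.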
bool NormFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type,
                              float *epsilon, int *begin_norm_axis, int *begin_params_axis) const {
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(epsilon != nullptr);
  MS_ASSERT(type != nullptr);
  MS_ASSERT(begin_norm_axis != nullptr);
  MS_ASSERT(begin_params_axis != nullptr);
  // beta
  auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
  MS_ASSERT(beta_node != nullptr);
  if (!beta_node->isa<Parameter>()) {
    return false;
  }
  auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
  MS_CHECK_TRUE_RET(beta_param != nullptr, false);
  auto beta_tensor = beta_param->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(beta_tensor != nullptr, false);
  std::vector<int> beta_shape;
  std::transform(beta_tensor->shape().begin(), beta_tensor->shape().end(), std::back_inserter(beta_shape),
                 [](int64_t val) { return static_cast<int>(val); });
  // gamma
  auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
  MS_ASSERT(gamma_node != nullptr);
  if (!gamma_node->isa<Parameter>()) {
    return false;
  }
  auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
  MS_CHECK_TRUE_RET(gamma_param != nullptr, false);
  auto gamma_tensor = gamma_param->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(gamma_tensor != nullptr, false);
  std::vector<int> gamma_shape;
  std::transform(gamma_tensor->shape().begin(), gamma_tensor->shape().end(), std::back_inserter(gamma_shape),
                 [](int64_t val) { return static_cast<int>(val); });
  // epsilon
  auto epsilon_node = utils::cast<AnfNodePtr>((*equiv)[epsilon_]);
  MS_ASSERT(epsilon_node != nullptr);
  if (!epsilon_node->isa<Parameter>()) {
    return false;
  }
  auto epsilon_param = epsilon_node->cast<ParameterPtr>()->default_param();
  MS_CHECK_TRUE_RET(epsilon_param != nullptr, false);
  auto epsilon_tensor = epsilon_param->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(epsilon_tensor != nullptr, false);
  auto epsilon_shape = epsilon_tensor->shape();
  // mean2
  std::vector<int> mean2_axes;
  if (!IsReduceNode(equiv, mean2_, mean2_axes_, &mean2_axes)) {
    return false;
  }
  // mean1
  std::vector<int> mean1_axes;
  if (!IsReduceNode(equiv, mean1_, mean1_axes_, &mean1_axes)) {
    return false;
  }
  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
  MS_ASSERT(input_node != nullptr);
  if (!utils::isa<CNodePtr>(input_node)) {
    return false;
  }
  auto input_cnode = input_node->cast<CNodePtr>();
  if (mean1_axes != mean2_axes) {
    return false;
  }
  if (gamma_shape != beta_shape) {
    return false;
  }
  if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
    MS_CHECK_TRUE_RET(epsilon_tensor->data_c() != nullptr, false);
    auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->data_c());
    *epsilon = epsilon_data[0];
  } else {
    return false;
  }
  return GetNormTypeAndAxis(func_graph, input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis);
}

namespace {
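// The helpers below propagate only the rank (number of dimensions) of each tensor, not the
// full shape; a rank is all ShapeSizeInfer needs when the real shape is unknown at fusion time.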
int CommonShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  return in_shape_size.at(0);
}

int ExpandDimsShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  return in_shape_size.at(0) + 1;
}

int StridedSliceShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  MS_ASSERT(primitive.value.AsStridedSlice() != nullptr);
  auto new_axis_mask = static_cast<size_t>(primitive.value.AsStridedSlice()->new_axis_mask);
  auto add_dims = 0;
  // each set bit in new_axis_mask inserts one extra dimension
  while (new_axis_mask != 0) {
    new_axis_mask = (new_axis_mask - 1) & new_axis_mask;
    add_dims++;
  }
  return in_shape_size.at(0) + add_dims;
}

int MatMulShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  return in_shape_size[0];
}

int ReShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  return in_shape_size[1];
}

int StackSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  return std::accumulate(in_shape_size.begin(), in_shape_size.end(), 0);
}

int SqueezeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  auto axis = primitive.value.AsSqueeze()->axis;
  if (axis.empty()) {
    return 0;
  }
  return in_shape_size.at(0) - axis.size();
}

int OneHotSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 0);
  return in_shape_size.at(0) + 1;
}

int FillShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  return in_shape_size.at(1);
}

int ShapeOpSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) { return 1; }

int BroadcastShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
  MS_ASSERT(in_shape_size.size() > 1);
  int result = 0;
  for (auto shape_size : in_shape_size) {
    result = std::max(result, shape_size);
  }
  return result;
}
}  // namespace

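// Walk the graph in topological order and record each CNode's output rank, dispatching to the
// per-primitive infer functions registered in shape_size_infer_registry_. Shape, StridedSlice,
// and Stack additionally track the concrete 1-D shape values they produce.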
std::map<string, int> NormFusion::ShapeSizeInfer(const FuncGraphPtr &func_graph) const {
  MS_ASSERT(func_graph != nullptr);
  std::map<string, int> node_shape_size;
  std::map<string, std::vector<int>> node_shape;
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    auto prim_t = lite::GetPrimitiveT(cnode->input(0));
    if (prim_t == nullptr) {
      continue;
    }
    auto prim_type = prim_t->value.type;
    auto shape_size_infer_iter = shape_size_infer_registry_.find(prim_type);
    if (shape_size_infer_iter == shape_size_infer_registry_.end()) {
      continue;
    }

    // specific op infer shape
    if (prim_type == schema::PrimitiveType_Shape) {
      tensor::TensorPtr tensor_info;
      auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, 1);
      if (ret == RET_OK) {
        node_shape[cnode->fullname_with_scope()] = {static_cast<int>(tensor_info->shape().size())};
      } else if (node_shape_size.find(cnode->input(1)->fullname_with_scope()) != node_shape_size.end()) {
        node_shape[cnode->fullname_with_scope()] = {node_shape_size[cnode->input(1)->fullname_with_scope()]};
      }
    } else if (prim_type == schema::PrimitiveType_StridedSlice) {
      node_shape[cnode->fullname_with_scope()] = node_shape[cnode->input(1)->fullname_with_scope()];
    } else if (prim_type == schema::PrimitiveType_Stack) {
      auto shape = node_shape[cnode->input(1)->fullname_with_scope()];
      shape.insert(shape.begin(), cnode->inputs().size() - 1);
      node_shape[cnode->fullname_with_scope()] = shape;
    }

    // Get in node shape size
    std::vector<int> in_shape_sizes;
    for (size_t i = 1; i < cnode->inputs().size(); i++) {
      int in_shape_size = 0;
      if (utils::isa<CNodePtr>(cnode->input(i))) {
        in_shape_size = node_shape_size[cnode->input(i)->fullname_with_scope()];
        // second input of reshape is shape
        if (prim_type == schema::PrimitiveType_Reshape && i == THIRD_INPUT &&
            node_shape.find(cnode->input(i)->fullname_with_scope()) != node_shape.end()) {
          in_shape_size = node_shape[cnode->input(i)->fullname_with_scope()].at(0);
        }
      } else {
        tensor::TensorPtr tensor_info;
        auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, i);
        if (ret == RET_OK) {
          in_shape_size = tensor_info->shape().size();
          // second input of reshape is shape
          if (prim_type == schema::PrimitiveType_Reshape && i == THIRD_INPUT) {
            in_shape_size = tensor_info->shape().at(0);
          }
        }
      }
      in_shape_sizes.emplace_back(in_shape_size);
    }
    // Cal shape size infer function
    auto shape_size_infer_func = shape_size_infer_iter->second;
    auto shape_size = shape_size_infer_func(in_shape_sizes, *prim_t);
    // Update node shape size map
    node_shape_size[cnode->fullname_with_scope()] = shape_size;
  }
  return node_shape_size;
}

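// Register a rank-infer function for every primitive type the pass may traverse; called once,
// the early return keeps repeated invocations cheap.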
void NormFusion::InitShapeSizeInferFuncMap() {
  if (!shape_size_infer_registry_.empty()) {
    return;
  }
  shape_size_infer_registry_[schema::PrimitiveType_Activation] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_AddFusion] = BroadcastShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_BiasAdd] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Stack] = StackSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Cast] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Concat] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_ExpandDims] = ExpandDimsShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Fill] = FillShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_LayerNormFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_MatMul] = MatMulShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_MulFusion] = BroadcastShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_OneHot] = OneHotSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_ReduceFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Reshape] = ReShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Shape] = ShapeOpSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_SliceFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Softmax] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Squeeze] = SqueezeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_StridedSlice] = StridedSliceShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Transpose] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_TileFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_SquaredDifference] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_Rsqrt] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_SubFusion] = BroadcastShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_PadFusion] = CommonShapeSizeInfer;
  shape_size_infer_registry_[schema::PrimitiveType_PowFusion] = CommonShapeSizeInfer;
}

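// Entry point invoked on a pattern match: verify the match with CheckPattern, then replace the
// matched sub-graph root (the final Add) with a single fused norm node carrying its abstract.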
const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                     const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    MS_LOG(ERROR) << "input param is nullptr, do norm fusion failed.";
    return nullptr;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto add2_cnode = node->cast<CNodePtr>();
  if (IsMarkedTrainOp(add2_cnode)) {
    return nullptr;
  }
  float epsilon = 0.0f;
  int begin_norm_axis = 0;
  int begin_params_axis = 0;
  schema::PrimitiveType type = schema::PrimitiveType_NONE;
  if (!CheckPattern(func_graph, equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) {
    return nullptr;
  }
  auto norm_cnode = CreateNormNode(func_graph, equiv, type, epsilon, begin_norm_axis, begin_params_axis);
  if (norm_cnode == nullptr) {
    MS_LOG(DEBUG) << "create norm cnode failed";
    return nullptr;
  }
  MS_CHECK_TRUE_RET(add2_cnode->abstract() != nullptr, nullptr);
  norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
  if (type == schema::PrimitiveType_LayerNormFusion) {
    norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
    MS_LOG(DEBUG) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
  } else if (type == schema::PrimitiveType_InstanceNorm) {
    norm_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope());
    MS_LOG(DEBUG) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
  }
  return norm_cnode;
}

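// TensorFlow-style decomposition. With scale = gamma * rsqrt(var + eps), the matched graph is
//   output = x * scale + (beta - mean * scale)
// which is algebraically gamma * (x - mean) / sqrt(var + eps) + beta, i.e. layer norm.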
const BaseRef TfNormFusion::DefinePattern() const {
  if (!Init()) {
    MS_LOG(ERROR) << "initial member failed.";
    return {};
  }
  VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
  auto is_squared_difference = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquaredDifference>);
  MS_CHECK_TRUE_RET(is_squared_difference != nullptr, {});
  VectorRef squared_difference1_ref = VectorRef({is_squared_difference, input_, mean1_ref});
  VectorRef mean2_ref = VectorRef({mean2_, squared_difference1_ref, mean2_axes_});
  auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
  auto is_rsqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRsqrt>);
  MS_CHECK_TRUE_RET(is_rsqrt != nullptr, {});
  VectorRef rsqrt1_ref = VectorRef({is_rsqrt, add1_ref});
  auto is_mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul2 != nullptr, {});
  VectorRef mul2_ref = VectorRef({is_mul2, rsqrt1_ref, gamma_});
  auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
  VectorRef mul1_ref = VectorRef({is_mul1, input_, mul2_ref});
  auto is_mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul3 != nullptr, {});
  VectorRef mul3_ref = VectorRef({is_mul3, mean1_ref, mul2_ref});
  auto is_sub = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
  MS_CHECK_TRUE_RET(is_sub != nullptr, {});
  VectorRef sub1_ref = VectorRef({is_sub, beta_, mul3_ref});
  auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  VectorRef add2_ref = VectorRef({is_add2, mul1_ref, sub1_ref});
  return add2_ref;
}

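// ONNX-style decomposition, matched literally:
//   output = gamma * (x - mean) / sqrt(mean((x - mean)^2) + eps) + beta
// The (x - mean) subtraction appears twice in the pattern so that graphs which compute it as
// two separate Sub nodes still match; the Pow exponent is left as a free Var.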
const BaseRef OnnxLayerNormFusion::DefinePattern() const {
  if (!Init()) {
    MS_LOG(ERROR) << "initial member failed.";
    return {};
  }
  VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
  auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
  MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
  VectorRef sub1_ref = VectorRef({is_sub1, input_, mean1_ref});
  auto is_sub2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
  MS_CHECK_TRUE_RET(is_sub2 != nullptr, {});
  VectorRef sub2_ref = VectorRef({is_sub2, input_, mean1_ref});
  auto is_pow = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>);
  MS_CHECK_TRUE_RET(is_pow != nullptr, {});
  auto is_var = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var != nullptr, {});
  VectorRef pow_ref = VectorRef({is_pow, sub2_ref, is_var});
  VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
  auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
  auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
  MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
  VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
  auto is_div = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
  MS_CHECK_TRUE_RET(is_div != nullptr, {});
  VectorRef div_ref = VectorRef({is_div, sub1_ref, sqrt_ref});
  auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  VectorRef mul_ref = VectorRef({is_mul, gamma_, div_ref});
  auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  VectorRef add2_ref = VectorRef({is_add2, mul_ref, beta_});
  return add2_ref;
}
}  // namespace opt
}  // namespace mindspore