1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/norm_fusion.h"
#include <algorithm>
#include <memory>
#include <numeric>
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/reduce_fusion.h"
#include "mindspore/core/ops/instance_norm.h"
#include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"
#include "src/common/ops/anf_utils.h"
#include "ops/op_utils.h"
33
34 namespace mindspore {
35 namespace opt {
36 namespace {
// Extracts the reduction axes bound to pattern node `n` into `axes`.
// `n` may be a Parameter carrying a const tensor default (typical after model
// import) or a ValueNode holding an immediate value.
// Returns RET_OK on success, RET_NOT_SUPPORT when the parameter has no
// default value, RET_ERROR for any malformed axes input.
STATUS GetReduceAxes(const BaseRef &n, std::vector<int> *axes) {
  MS_ASSERT(axes != nullptr);
  if (utils::isa<ParameterPtr>(n)) {
    auto axes_param = utils::cast<ParameterPtr>(n);
    if (!axes_param->has_default() || axes_param->default_param() == nullptr) {
      return lite::RET_NOT_SUPPORT;
    }
    auto axes_value = axes_param->default_param()->cast<tensor::TensorPtr>();
    if (axes_value == nullptr) {
      return lite::RET_ERROR;
    }
    // Only 32-bit integer axes are supported below (copied as raw ints).
    if (axes_value->data_type() != kNumberTypeInt && axes_value->data_type() != kNumberTypeInt32) {
      MS_LOG(ERROR) << "reduce's axes should be integer, now is " << axes_value->data_type();
      return lite::RET_ERROR;
    }
    if (axes_value->data_c() == nullptr) {
      return lite::RET_ERROR;
    }
    // Axes tensor must be a scalar or 1-D.
    if (axes_value->shape().size() > 1) {
      return lite::RET_ERROR;
    }
    // A scalar axes tensor still yields exactly one element.
    axes->resize(1);
    if (!axes_value->shape().empty()) {
      MS_CHECK_GE(axes_value->shape()[0], 0, lite::RET_ERROR);
      axes->resize(static_cast<size_t>(axes_value->shape()[0]));
    }
    // Copy the raw int32 payload; element type was validated above.
    if (memcpy_s(axes->data(), axes->size() * sizeof(int), axes_value->data_c(), axes_value->Size()) == EOK) {
      return lite::RET_OK;
    }
    // NOTE(review): on memcpy_s failure control intentionally falls through
    // to the ValueNode branch below and ends up returning RET_ERROR.
  }
  if (utils::isa<ValueNodePtr>(n)) {
    auto axes_value_node = utils::cast<ValueNodePtr>(n);
    *axes = CastToInt(axes_value_node->value());
    return lite::RET_OK;
  }
  return lite::RET_ERROR;
}
74
IsReduceNode(const EquivPtr & equiv,const VarPtr & input_prim,const VarPtr & input_axes,std::vector<int> * axes)75 bool IsReduceNode(const EquivPtr &equiv, const VarPtr &input_prim, const VarPtr &input_axes, std::vector<int> *axes) {
76 MS_ASSERT(equiv != nullptr && input_prim != nullptr && input_axes != nullptr && axes != nullptr);
77 auto reduce_value = utils::cast<AnfNodePtr>((*equiv)[input_prim]);
78 MS_ASSERT(reduce_value != nullptr);
79 auto mean2_primitive = ops::GetOperator<ops::ReduceFusion>(reduce_value);
80 MS_CHECK_TRUE_RET(mean2_primitive != nullptr, false);
81 auto mean2_primitive_c = mean2_primitive->GetPrim();
82 if (mean2_primitive_c->GetAttr(ops::kMode) == nullptr || mean2_primitive->get_mode() != mindspore::Reduce_Mean) {
83 return false;
84 }
85 if (GetReduceAxes((*equiv)[input_axes], axes) != lite::RET_OK) {
86 return false;
87 }
88 return true;
89 }
90 } // namespace
91
Init() const92 bool NormFusion::Init() const {
93 input_ = std::make_shared<Var>();
94 MS_CHECK_TRUE_RET(input_ != nullptr, false);
95 mean1_ = std::make_shared<Var>();
96 MS_CHECK_TRUE_RET(mean1_ != nullptr, false);
97 mean1_axes_ = std::make_shared<Var>();
98 MS_CHECK_TRUE_RET(mean1_axes_ != nullptr, false);
99 mean2_ = std::make_shared<Var>();
100 MS_CHECK_TRUE_RET(mean2_ != nullptr, false);
101 mean2_axes_ = std::make_shared<Var>();
102 MS_CHECK_TRUE_RET(mean2_axes_ != nullptr, false);
103 gamma_ = std::make_shared<Var>();
104 MS_CHECK_TRUE_RET(gamma_ != nullptr, false);
105 beta_ = std::make_shared<Var>();
106 MS_CHECK_TRUE_RET(beta_ != nullptr, false);
107 epsilon_ = std::make_shared<Var>();
108 MS_CHECK_TRUE_RET(epsilon_ != nullptr, false);
109 return true;
110 }
111
CreateNormNode(const FuncGraphPtr & func_graph,const EquivPtr & equiv,const schema::PrimitiveType type,float epsilon,int begin_norm_axis,int begin_params_axis) const112 CNodePtr NormFusion::CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
113 const schema::PrimitiveType type, float epsilon, int begin_norm_axis,
114 int begin_params_axis) const {
115 MS_ASSERT(func_graph != nullptr);
116 MS_ASSERT(equiv != nullptr);
117 PrimitiveCPtr primitive_c = nullptr;
118 if (type == schema::PrimitiveType_LayerNormFusion) {
119 auto layer_norm_primitive = std::make_shared<ops::LayerNormFusion>();
120 MS_CHECK_TRUE_RET(layer_norm_primitive != nullptr, nullptr);
121 layer_norm_primitive->Init(begin_norm_axis, begin_params_axis, epsilon, true);
122 auto layer_norm_primitive_c = layer_norm_primitive->GetPrim();
123 MS_CHECK_TRUE_RET(layer_norm_primitive_c != nullptr, nullptr);
124 primitive_c = layer_norm_primitive_c;
125 } else if (type == schema::PrimitiveType_InstanceNorm) {
126 auto instance_norm_primitive = std::make_shared<ops::InstanceNorm>();
127 MS_CHECK_TRUE_RET(instance_norm_primitive != nullptr, nullptr);
128 auto instance_norm_primitive_c = instance_norm_primitive->GetPrim();
129 MS_CHECK_TRUE_RET(instance_norm_primitive_c != nullptr, nullptr);
130 instance_norm_primitive->Init(epsilon);
131 primitive_c = instance_norm_primitive_c;
132 } else {
133 return nullptr;
134 }
135 auto value_node = NewValueNode(primitive_c);
136 MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
137 std::vector<AnfNodePtr> new_node_inputs = {value_node};
138 auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
139 MS_ASSERT(input_node != nullptr);
140 new_node_inputs.push_back(input_node);
141 auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
142 MS_ASSERT(gamma_node != nullptr);
143 new_node_inputs.push_back(gamma_node);
144 auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
145 MS_ASSERT(beta_node != nullptr);
146 new_node_inputs.push_back(beta_node);
147 auto new_node = func_graph->NewCNode(new_node_inputs);
148 return new_node;
149 }
150
// Classifies the matched pattern as InstanceNorm or LayerNormFusion and
// computes the layer-norm axes:
// - InstanceNorm: 4-D input reduced over axes [1, 2] with 1-D params whose
//   length equals the last input dim (NHWC channel count).
// - LayerNormFusion: the reduced axes must be contiguous and end at the last
//   axis; begin_params_axis is derived from the params rank.
// Returns false when the pattern fits neither form.
bool NormFusion::GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode,
                                    const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
                                    schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(input_cnode != nullptr);
  MS_ASSERT(type != nullptr);
  MS_ASSERT(begin_norm_axis != nullptr);
  MS_ASSERT(begin_params_axis != nullptr);
  auto abstract = input_cnode->abstract();
  if (abstract == nullptr) {
    MS_LOG(DEBUG) << "abstract of input is nullptr";
    return false;
  }
  ShapeVector shape;
  if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
    MS_LOG(ERROR) << "fetch shape failed.";
    return false;
  }
  int shape_size = static_cast<int>(shape.size());
  // For dynamic shapes, fall back to the rank propagated by ShapeSizeInfer.
  if (lite::JudgeDynamicShape(shape)) {
    auto shape_size_map = ShapeSizeInfer(func_graph);
    if (shape_size_map.find(input_cnode->fullname_with_scope()) != shape_size_map.end()) {
      shape_size = shape_size_map[input_cnode->fullname_with_scope()];
    }
  }

  // Reduced axes must be consecutive, e.g. [1, 2] or [2, 3].
  for (size_t i = 1; i < mean_axes.size(); ++i) {
    if (mean_axes[i] != mean_axes[i - 1] + 1) {
      MS_LOG(DEBUG) << "mean axes is not continuous";
      return false;
    }
  }
  // shape input has 4 dim && mean input has 2 dim and mean is in [1, 2 ,...]
  if (shape_size == 4 && mean_axes.size() == 2 && mean_axes[0] == 1 && mean_axes[1] == 2) {
    if (params_shape.size() == 1 && params_shape.back() == shape.back()) {
      *type = schema::PrimitiveType_InstanceNorm;
      return true;
    }
  }
  // Layer norm requires reduction up to (and including) the last axis.
  // Negative axes are accepted as-is; only non-negative ones are validated.
  if (mean_axes.back() >= 0 && mean_axes.back() + 1 != shape_size) {
    MS_LOG(DEBUG) << "mean node is not reduce to last axis.";
    return false;
  }

  // there is no need to check params_shape
  *begin_norm_axis = mean_axes.front();
  if (*begin_norm_axis >= 0) {
    *begin_params_axis = shape_size - static_cast<int>(params_shape.size());
    if (*begin_params_axis < 0) {
      MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse";
      return false;
    }
  } else {
    // Negative begin_norm_axis: express begin_params_axis negatively too.
    *begin_params_axis = -static_cast<int>(params_shape.size());
  }

  *type = schema::PrimitiveType_LayerNormFusion;
  return true;
}
210
// Validates the matched norm pattern and extracts its attributes:
// gamma/beta must be const parameters with identical shapes, both Reduce ops
// must be Mean with identical axes, and epsilon must be a scalar tensor.
// On success delegates to GetNormTypeAndAxis to classify the norm type and
// compute begin_norm_axis / begin_params_axis.
bool NormFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type,
                              float *epsilon, int *begin_norm_axis, int *begin_params_axis) const {
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(epsilon != nullptr);
  MS_ASSERT(type != nullptr);
  MS_ASSERT(begin_norm_axis != nullptr);
  MS_ASSERT(begin_params_axis != nullptr);
  // beta: must be a Parameter with a const tensor default.
  auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
  MS_ASSERT(beta_node != nullptr);
  if (!beta_node->isa<Parameter>()) {
    return false;
  }
  auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
  MS_CHECK_TRUE_RET(beta_param != nullptr, false);
  auto beta_tensor = beta_param->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(beta_tensor != nullptr, false);
  std::vector<int> beta_shape;
  std::transform(beta_tensor->shape().begin(), beta_tensor->shape().end(), std::back_inserter(beta_shape),
                 [](int64_t val) { return static_cast<int>(val); });
  // gamma: same requirements as beta.
  auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
  MS_ASSERT(gamma_node != nullptr);
  if (!gamma_node->isa<Parameter>()) {
    return false;
  }
  auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
  MS_CHECK_TRUE_RET(gamma_param != nullptr, false);
  auto gamma_tensor = gamma_param->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(gamma_tensor != nullptr, false);
  std::vector<int> gamma_shape;
  std::transform(gamma_tensor->shape().begin(), gamma_tensor->shape().end(), std::back_inserter(gamma_shape),
                 [](int64_t val) { return static_cast<int>(val); });
  // epsilon: must be a const parameter as well.
  auto epsilon_node = utils::cast<AnfNodePtr>((*equiv)[epsilon_]);
  MS_ASSERT(epsilon_node != nullptr);
  if (!epsilon_node->isa<Parameter>()) {
    return false;
  }
  auto epsilon_param = epsilon_node->cast<ParameterPtr>()->default_param();
  MS_CHECK_TRUE_RET(epsilon_param != nullptr, false);
  auto epsilon_tensor = epsilon_param->cast<tensor::TensorPtr>();
  MS_CHECK_TRUE_RET(epsilon_tensor != nullptr, false);
  auto epsilon_shape = epsilon_tensor->shape();
  // mean2: the variance-side reduction must be a Mean with known axes.
  std::vector<int> mean2_axes;
  if (!IsReduceNode(equiv, mean2_, mean2_axes_, &mean2_axes)) {
    return false;
  }
  // mean1: the mean-side reduction must be a Mean with known axes.
  std::vector<int> mean1_axes;
  if (!IsReduceNode(equiv, mean1_, mean1_axes_, &mean1_axes)) {
    return false;
  }
  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
  MS_ASSERT(input_node != nullptr);
  if (!utils::isa<CNodePtr>(input_node)) {
    return false;
  }
  auto input_cnode = input_node->cast<CNodePtr>();
  // Both reductions must use the same axes, and gamma/beta must match.
  if (mean1_axes != mean2_axes) {
    return false;
  }
  if (gamma_shape != beta_shape) {
    return false;
  }
  // epsilon must be a scalar: rank 0 or shape [1].
  if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
    MS_CHECK_TRUE_RET(epsilon_tensor->data_c() != nullptr, false);
    // NOTE(review): data is read as float32 without a dtype check — assumes
    // the exporter always stores epsilon as float32; confirm.
    auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->data_c());
    *epsilon = epsilon_data[0];
  } else {
    return false;
  }
  return GetNormTypeAndAxis(func_graph, input_cnode, mean1_axes, gamma_shape, type, begin_norm_axis, begin_params_axis);
}
286
287 namespace {
CommonShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)288 int CommonShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
289 MS_ASSERT(in_shape_size.size() > 0);
290 return in_shape_size.at(0);
291 }
292
ExpandDimsShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)293 int ExpandDimsShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
294 MS_ASSERT(in_shape_size.size() > 0);
295 return in_shape_size.at(0) + 1;
296 }
297
StridedSliceShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)298 int StridedSliceShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
299 MS_ASSERT(in_shape_size.size() > 0);
300 MS_ASSERT(primitive.value.AsStridedSlice() != nullptr);
301 auto new_axis_mask = static_cast<size_t>(primitive.value.AsStridedSlice()->new_axis_mask);
302 auto add_dims = 0;
303 while (new_axis_mask != 0) {
304 new_axis_mask = (new_axis_mask - 1) & new_axis_mask;
305 add_dims++;
306 }
307 return in_shape_size.at(0) + add_dims;
308 }
309
MatMulShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)310 int MatMulShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
311 MS_ASSERT(in_shape_size.size() > 1);
312 return in_shape_size[0];
313 }
314
ReShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)315 int ReShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
316 MS_ASSERT(in_shape_size.size() > 1);
317 return in_shape_size[1];
318 }
319
StackSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)320 int StackSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
321 MS_ASSERT(in_shape_size.size() > 1);
322 return std::accumulate(in_shape_size.begin(), in_shape_size.end(), 0);
323 }
324
SqueezeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)325 int SqueezeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
326 MS_ASSERT(in_shape_size.size() > 0);
327 auto axis = primitive.value.AsSqueeze()->axis;
328 if (axis.empty()) {
329 return 0;
330 }
331 return in_shape_size.at(0) - axis.size();
332 }
333
OneHotSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)334 int OneHotSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
335 MS_ASSERT(in_shape_size.size() > 0);
336 return in_shape_size.at(0) + 1;
337 }
338
FillShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)339 int FillShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
340 MS_ASSERT(in_shape_size.size() > 1);
341 return in_shape_size.at(1);
342 }
343
ShapeOpSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)344 int ShapeOpSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) { return 1; }
345
BroadcastShapeSizeInfer(const std::vector<int> & in_shape_size,const schema::PrimitiveT & primitive)346 int BroadcastShapeSizeInfer(const std::vector<int> &in_shape_size, const schema::PrimitiveT &primitive) {
347 MS_ASSERT(in_shape_size.size() > 1);
348 int result = 0;
349 for (auto shape_size : in_shape_size) {
350 result = std::max(result, shape_size);
351 }
352 return result;
353 }
354 } // namespace
355
// Propagates tensor ranks ("shape sizes") through the graph in topological
// order, using the per-op infer functions registered in
// shape_size_infer_registry_.  Used when the input's static shape is dynamic
// and a rank must still be determined for norm classification.
// Returns a map from cnode full name to inferred rank.
std::map<string, int> NormFusion::ShapeSizeInfer(const FuncGraphPtr &func_graph) const {
  MS_ASSERT(func_graph != nullptr);
  std::map<string, int> node_shape_size;
  // node_shape tracks concrete 1-D shape *values* for ops (Shape, Stack,
  // StridedSlice) whose outputs feed Reshape's shape input.
  std::map<string, std::vector<int>> node_shape;
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_ASSERT(cnode != nullptr);
    auto prim_t = lite::GetPrimitiveT(cnode->input(0));
    if (prim_t == nullptr) {
      continue;
    }
    auto prim_type = prim_t->value.type;
    auto shape_size_infer_iter = shape_size_infer_registry_.find(prim_type);
    // Ops without a registered infer function are skipped entirely.
    if (shape_size_infer_iter == shape_size_infer_registry_.end()) {
      continue;
    }

    // specific op infer shape
    if (prim_type == schema::PrimitiveType_Shape) {
      tensor::TensorPtr tensor_info;
      auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, 1);
      if (ret == RET_OK) {
        // Shape's output value is the input's rank as a single element.
        node_shape[cnode->fullname_with_scope()] = {static_cast<int>(tensor_info->shape().size())};
      } else if (node_shape_size.find(cnode->input(1)->fullname_with_scope()) != node_shape_size.end()) {
        // Fall back to the rank inferred earlier in this walk.
        node_shape[cnode->fullname_with_scope()] = {node_shape_size[cnode->input(1)->fullname_with_scope()]};
      }
    } else if (prim_type == schema::PrimitiveType_StridedSlice) {
      // Shape value passes through a slice unchanged (rank-wise).
      node_shape[cnode->fullname_with_scope()] = node_shape[cnode->input(1)->fullname_with_scope()];
    } else if (prim_type == schema::PrimitiveType_Stack) {
      // Stacking N shape vectors prepends a dimension of size N.
      auto shape = node_shape[cnode->input(1)->fullname_with_scope()];
      shape.insert(shape.begin(), cnode->size() - 1);
      node_shape[cnode->fullname_with_scope()] = shape;
    }

    // Get in node shape size
    std::vector<int> in_shape_sizes;
    for (size_t i = 1; i < cnode->size(); i++) {
      int in_shape_size = 0;
      if (utils::isa<CNodePtr>(cnode->input(i))) {
        in_shape_size = node_shape_size[cnode->input(i)->fullname_with_scope()];
        // second input of reshape is shape
        if (prim_type == schema::PrimitiveType_Reshape && i == THIRD_INPUT &&
            node_shape.find(cnode->input(i)->fullname_with_scope()) != node_shape.end()) {
          in_shape_size = node_shape[cnode->input(i)->fullname_with_scope()].at(0);
        }
      } else {
        // Const/parameter input: read the rank straight from its abstract.
        tensor::TensorPtr tensor_info;
        auto ret = GetTensorInfoFromAbstract(&tensor_info, cnode, i);
        if (ret == RET_OK) {
          in_shape_size = tensor_info->shape().size();
          // second input of reshape is shape
          if (prim_type == schema::PrimitiveType_Reshape && i == THIRD_INPUT) {
            in_shape_size = tensor_info->shape().at(0);
          }
        }
      }
      in_shape_sizes.emplace_back(in_shape_size);
    }
    // Cal shape size infer function
    auto shape_size_infer_func = shape_size_infer_iter->second;
    auto shape_size = shape_size_infer_func(in_shape_sizes, *prim_t);
    // Update node shape size map
    node_shape_size[cnode->fullname_with_scope()] = shape_size;
  }
  return node_shape_size;
}
426
InitShapeSizeInferFuncMap()427 void NormFusion::InitShapeSizeInferFuncMap() {
428 if (!shape_size_infer_registry_.empty()) {
429 return;
430 }
431 shape_size_infer_registry_[schema::PrimitiveType_Activation] = CommonShapeSizeInfer;
432 shape_size_infer_registry_[schema::PrimitiveType_AddFusion] = BroadcastShapeSizeInfer;
433 shape_size_infer_registry_[schema::PrimitiveType_BiasAdd] = CommonShapeSizeInfer;
434 shape_size_infer_registry_[schema::PrimitiveType_Stack] = StackSizeInfer;
435 shape_size_infer_registry_[schema::PrimitiveType_Cast] = CommonShapeSizeInfer;
436 shape_size_infer_registry_[schema::PrimitiveType_Concat] = CommonShapeSizeInfer;
437 shape_size_infer_registry_[schema::PrimitiveType_ExpandDims] = ExpandDimsShapeSizeInfer;
438 shape_size_infer_registry_[schema::PrimitiveType_Fill] = FillShapeSizeInfer;
439 shape_size_infer_registry_[schema::PrimitiveType_LayerNormFusion] = CommonShapeSizeInfer;
440 shape_size_infer_registry_[schema::PrimitiveType_MatMulFusion] = MatMulShapeSizeInfer;
441 shape_size_infer_registry_[schema::PrimitiveType_MulFusion] = BroadcastShapeSizeInfer;
442 shape_size_infer_registry_[schema::PrimitiveType_OneHot] = OneHotSizeInfer;
443 shape_size_infer_registry_[schema::PrimitiveType_ReduceFusion] = CommonShapeSizeInfer;
444 shape_size_infer_registry_[schema::PrimitiveType_Reshape] = ReShapeSizeInfer;
445 shape_size_infer_registry_[schema::PrimitiveType_Shape] = ShapeOpSizeInfer;
446 shape_size_infer_registry_[schema::PrimitiveType_SliceFusion] = CommonShapeSizeInfer;
447 shape_size_infer_registry_[schema::PrimitiveType_Softmax] = CommonShapeSizeInfer;
448 shape_size_infer_registry_[schema::PrimitiveType_Squeeze] = SqueezeSizeInfer;
449 shape_size_infer_registry_[schema::PrimitiveType_StridedSlice] = StridedSliceShapeSizeInfer;
450 shape_size_infer_registry_[schema::PrimitiveType_Transpose] = CommonShapeSizeInfer;
451 shape_size_infer_registry_[schema::PrimitiveType_TileFusion] = CommonShapeSizeInfer;
452 shape_size_infer_registry_[schema::PrimitiveType_SquaredDifference] = CommonShapeSizeInfer;
453 shape_size_infer_registry_[schema::PrimitiveType_Rsqrt] = CommonShapeSizeInfer;
454 shape_size_infer_registry_[schema::PrimitiveType_SubFusion] = BroadcastShapeSizeInfer;
455 shape_size_infer_registry_[schema::PrimitiveType_PadFusion] = CommonShapeSizeInfer;
456 shape_size_infer_registry_[schema::PrimitiveType_PowFusion] = CommonShapeSizeInfer;
457 }
458
CreateActivationNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node) const459 CNodePtr NormFusion::CreateActivationNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const {
460 auto act_primitive = std::make_shared<ops::Activation>();
461 MS_CHECK_TRUE_RET(act_primitive != nullptr, nullptr);
462 act_primitive->set_activation_type(add_act_type_);
463 auto act_primitive_c = act_primitive->GetPrim();
464 MS_CHECK_TRUE_RET(act_primitive_c != nullptr, nullptr);
465 auto value_node = NewValueNode(act_primitive_c);
466 MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
467 std::vector<AnfNodePtr> new_node_inputs = {value_node};
468 new_node_inputs.push_back(node);
469 auto new_node = func_graph->NewCNode(new_node_inputs);
470 return new_node;
471 }
472
// Entry point of the fusion pass: called by the pattern matcher with the
// root Add node of a matched norm decomposition.  Validates the pattern,
// builds the fused LayerNorm/InstanceNorm cnode, and — when the Add carried
// a fused activation — inserts an Activation node after the fused op and
// rewires the Add's consumers to it.  Returns the fused node (which replaces
// `node`), or nullptr to skip fusion.
const AnfNodePtr NormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                     const EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    MS_LOG(ERROR) << "input param is nullptr, do norm fusion failed.";
    return nullptr;
  }
  if (!utils::isa<CNodePtr>(node)) {
    return nullptr;
  }
  auto add2_cnode = node->cast<CNodePtr>();
  auto add2_primitive = ops::GetOperator<ops::AddFusion>(add2_cnode->input(0));
  MS_CHECK_TRUE_RET(add2_primitive != nullptr, nullptr);
  auto add2_primitive_c = add2_primitive->GetPrim();
  // NOTE(review): add_act_type_ is only overwritten when the attr exists —
  // it appears to retain the value from a previous match otherwise;
  // confirm it is reset between matches.
  if (add2_primitive_c->GetAttr(ops::kActivationType) != nullptr) {
    add_act_type_ = add2_primitive->get_activation_type();
  }
  if (IsMarkedTrainOp(add2_cnode)) {
    return nullptr;
  }
  float epsilon = 0.0f;
  int begin_norm_axis = 0;
  int begin_params_axis = 0;
  schema::PrimitiveType type = schema::PrimitiveType_NONE;
  if (!CheckPattern(func_graph, equiv, &type, &epsilon, &begin_norm_axis, &begin_params_axis)) {
    return nullptr;
  }
  auto norm_cnode = CreateNormNode(func_graph, equiv, type, epsilon, begin_norm_axis, begin_params_axis);
  if (norm_cnode == nullptr) {
    MS_LOG(DEBUG) << "create norm cnode failed";
    return nullptr;
  }
  // The fused node inherits the Add's output abstract (type/shape).
  MS_CHECK_TRUE_RET(add2_cnode->abstract() != nullptr, nullptr);
  norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
  if (type == schema::PrimitiveType_LayerNormFusion) {
    norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
    MS_LOG(DEBUG) << "layer_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
  } else if (type == schema::PrimitiveType_InstanceNorm) {
    norm_cnode->set_fullname_with_scope("instance_norm_" + add2_cnode->fullname_with_scope());
    MS_LOG(DEBUG) << "instance_norm node:" << norm_cnode->fullname_with_scope() << " fusion success";
  }
  // insert relu node
  if (add_act_type_ != ActivationType::NO_ACTIVATION) {
    auto new_act_node = CreateActivationNode(func_graph, node);
    if (new_act_node == nullptr) {
      MS_LOG(ERROR) << "create act cnode failed.";
      return nullptr;
    }
    new_act_node->set_fullname_with_scope("relu_" + add2_cnode->fullname_with_scope());
    auto manager = func_graph->manager();
    if (manager == nullptr) {
      manager = Manage(func_graph, true);
      MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
    }
    // Redirect every consumer of the Add to the new activation node; the
    // activation itself still reads `node`, which the caller replaces with
    // the returned norm cnode.
    auto node_users = manager->node_users()[add2_cnode];
    for (auto &node_user : node_users) {
      auto next_cnode = node_user.first->cast<CNodePtr>();
      MS_CHECK_TRUE_RET(next_cnode != nullptr, nullptr);
      manager->SetEdge(next_cnode, node_user.second, new_act_node);
    }
    UpdateManager(func_graph);
  }
  return norm_cnode;
}
536
DefinePattern() const537 const BaseRef TfNormFusion::DefinePattern() const {
538 if (!Init()) {
539 MS_LOG(ERROR) << "initial member failed.";
540 return {};
541 }
542 VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
543 auto is_squared_diffference = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSquaredDifference>);
544 MS_CHECK_TRUE_RET(is_squared_diffference != nullptr, {});
545 VectorRef squared_diffference1_ref = VectorRef({is_squared_diffference, input_, mean1_ref});
546 VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref, mean2_axes_});
547 auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
548 MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
549 VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
550 auto is_rsqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRsqrt>);
551 MS_CHECK_TRUE_RET(is_rsqrt != nullptr, {});
552 VectorRef rsqrt1_ref = VectorRef({is_rsqrt, add1_ref});
553 auto is_mul2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
554 MS_CHECK_TRUE_RET(is_mul2 != nullptr, {});
555 VectorRef mul2_ref = VectorRef({is_mul2, rsqrt1_ref, gamma_});
556 auto is_mul1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
557 MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
558 VectorRef mul1_ref = VectorRef({is_mul1, input_, mul2_ref});
559 auto is_mul3 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
560 MS_CHECK_TRUE_RET(is_mul3 != nullptr, {});
561 VectorRef mul3_ref = VectorRef({is_mul3, mean1_ref, mul2_ref});
562 auto is_sub = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
563 MS_CHECK_TRUE_RET(is_sub != nullptr, {});
564 VectorRef sub1_ref = VectorRef({is_sub, beta_, mul3_ref});
565 auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
566 MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
567 VectorRef add2_ref = VectorRef({is_add2, mul1_ref, sub1_ref});
568 return add2_ref;
569 }
570
DefinePattern() const571 const BaseRef OnnxLayerNormFusion::DefinePattern() const {
572 if (!Init()) {
573 MS_LOG(ERROR) << "initial member failed.";
574 return {};
575 }
576 VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
577 auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
578 MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
579 VectorRef sub1_ref = VectorRef({is_sub1, input_, mean1_ref});
580 auto is_sub2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
581 MS_CHECK_TRUE_RET(is_sub2 != nullptr, {});
582 VectorRef sub2_ref = VectorRef({is_sub2, input_, mean1_ref});
583 auto is_pow = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>);
584 MS_CHECK_TRUE_RET(is_pow != nullptr, {});
585 auto is_var = std::make_shared<Var>();
586 MS_CHECK_TRUE_RET(is_var != nullptr, {});
587 VectorRef pow_ref = VectorRef({is_pow, sub2_ref, is_var});
588 VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
589 auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
590 MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
591 VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
592 auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
593 MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
594 VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
595 auto is_div = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
596 MS_CHECK_TRUE_RET(is_div != nullptr, {});
597 VectorRef div_ref = VectorRef({is_div, sub1_ref, sqrt_ref});
598 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
599 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
600 VectorRef mul_ref = VectorRef({is_mul, gamma_, div_ref});
601 auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
602 MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
603 VectorRef add2_ref = VectorRef({is_add2, mul_ref, beta_});
604 return add2_ref;
605 }
606
607 // little different from OnnxLayerNormFusion on mul's inputs order
// Identical to OnnxLayerNormFusion's pattern except the final Mul's operand
// order: here it is Mul(div, gamma) instead of Mul(gamma, div).
const BaseRef OnnxLayerNormFusion2::DefinePattern() const {
  if (!Init()) {
    MS_LOG(ERROR) << "initial member failed.";
    return {};
  }
  // mean1 = ReduceMean(input)
  VectorRef mean1_ref = VectorRef({mean1_, input_, mean1_axes_});
  // sub1 = input - mean1 (numerator path)
  auto is_sub1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
  MS_CHECK_TRUE_RET(is_sub1 != nullptr, {});
  VectorRef sub1_ref = VectorRef({is_sub1, input_, mean1_ref});
  // sub2 = input - mean1 (variance path; may be a distinct node in the graph)
  auto is_sub2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSubFusion>);
  MS_CHECK_TRUE_RET(is_sub2 != nullptr, {});
  VectorRef sub2_ref = VectorRef({is_sub2, input_, mean1_ref});
  // pow = sub2 ^ <any exponent var>
  auto is_pow = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimPowFusion>);
  MS_CHECK_TRUE_RET(is_pow != nullptr, {});
  auto is_var = std::make_shared<Var>();
  MS_CHECK_TRUE_RET(is_var != nullptr, {});
  VectorRef pow_ref = VectorRef({is_pow, sub2_ref, is_var});
  // mean2 = ReduceMean(pow) — the variance estimate
  VectorRef mean2_ref = VectorRef({mean2_, pow_ref, mean2_axes_});
  auto is_add1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  VectorRef add1_ref = VectorRef({is_add1, mean2_ref, epsilon_});
  auto is_sqrt = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>);
  MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
  VectorRef sqrt_ref = VectorRef({is_sqrt, add1_ref});
  // div = sub1 / sqrt(var + eps)
  auto is_div = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>);
  MS_CHECK_TRUE_RET(is_div != nullptr, {});
  VectorRef div_ref = VectorRef({is_div, sub1_ref, sqrt_ref});
  // mul = div * gamma (operand order differs from OnnxLayerNormFusion)
  auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  VectorRef mul_ref = VectorRef({is_mul, div_ref, gamma_});
  // add2 = mul + beta — the pattern root
  auto is_add2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  VectorRef add2_ref = VectorRef({is_add2, mul_ref, beta_});
  return add2_ref;
}
643 } // namespace opt
644 } // namespace mindspore
645