/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/mul_add_fusion.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "mindspore/core/ops/lite_ops.h"
#include "nnacl/op_base.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/scale_fusion.h"
#include "ops/op_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore::opt {
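// Matches Add(Mul(...), const): the Mul output feeds the first input of Add
// and the second input is a constant (parameter or value node with data).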
VectorRef MulAddFusion::DefineMulFirstPattern() const {
  auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
  MS_CHECK_TRUE_RET(is_const != nullptr, {});
  return VectorRef({is_add, is_mul, is_const});
}

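// Matches Add(const, Mul(...)): same subgraph as above, with the constant as
// the first input of Add and the Mul output as the second.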
VectorRef MulAddFusion::DefineMulSecondPattern() const {
  auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
  MS_CHECK_TRUE_RET(is_const != nullptr, {});
  auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  return VectorRef({is_add, is_const, is_mul});
}

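// Registers both input orderings so the fusion fires regardless of which side
// of the Add carries the constant.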
std::unordered_map<std::string, VectorRef> MulAddFusion::DefinePatterns() const {
  std::unordered_map<std::string, VectorRef> patterns;
  patterns["MulFirstPatternName"] = DefineMulFirstPattern();
  patterns["MulSecondPatternName"] = DefineMulSecondPattern();
  return patterns;
}

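// An Add node is fusable only if it has exactly two inputs, is not a training
// op, carries no initialized quantization parameters, and uses no activation
// other than ReLU or ReLU6. The activation type is recorded so it can be
// replayed on the fused Scale node.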
bool MulAddFusion::CheckAddNode(const mindspore::CNodePtr &cnode) const {
  MS_CHECK_TRUE_RET(cnode != nullptr, false);
  if (cnode->size() != kInputSizeThree) {
    MS_LOG(DEBUG) << "Add op is null or has an invalid input size";
    return false;
  }
  if (IsMarkedTrainOp(cnode)) {
    return false;
  }
  auto add_primitive = ops::GetOperator<ops::AddFusion>(cnode->input(0));
  MS_CHECK_TRUE_RET(add_primitive != nullptr, false);
  auto add_primitive_c = add_primitive->GetPrim();
  MS_CHECK_TRUE_RET(add_primitive_c != nullptr, false);
  auto quant_attr = add_primitive_c->GetAttr("quant_params");
  if (quant_attr != nullptr) {
    auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
    MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
    auto quant_params = quant_param_holder->get_input_quant_params();
    bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> &params) {
      return !params.empty() && params.front().inited;
    });
    if (is_quant) {
      return false;
    }
  }

  ActivationType add_act_type = ActivationType::NO_ACTIVATION;
  if (add_primitive_c->GetAttr(ops::kActivationType) != nullptr) {
    add_act_type = add_primitive->get_activation_type();
    if (add_act_type != ActivationType::RELU && add_act_type != ActivationType::RELU6 &&
        add_act_type != ActivationType::NO_ACTIVATION) {
      MS_LOG(DEBUG) << "Only support add node with relu or relu6 or no activation";
      return false;
    }
  }
  scale_act_type_ = add_act_type;
  return true;
}

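// A Mul node is fusable only if its output is not shared by multiple
// consumers, it has exactly two inputs, is not a training op, carries no
// initialized quantization parameters, and has no fused activation of its own.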
bool MulAddFusion::CheckMulNode(const mindspore::FuncGraphPtr &func_graph, const mindspore::CNodePtr &cnode) const {
  MS_ASSERT(func_graph != nullptr);
  MS_CHECK_TRUE_RET(cnode != nullptr, false);
  if (IsMultiOutputTensors(func_graph, cnode)) {
    MS_LOG(DEBUG) << "Mul op has multi-output";
    return false;
  }
  if (IsMarkedTrainOp(cnode)) {
    return false;
  }
  auto mul_primitive = ops::GetOperator<ops::MulFusion>(cnode->input(0));
  MS_CHECK_TRUE_RET(mul_primitive != nullptr, false);
  auto mul_primitive_c = mul_primitive->GetPrim();
  MS_CHECK_TRUE_RET(mul_primitive_c != nullptr, false);
  auto quant_attr = mul_primitive_c->GetAttr("quant_params");
  if (quant_attr != nullptr) {
    auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
    MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
    auto quant_params = quant_param_holder->get_input_quant_params();
    bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> &params) {
      return !params.empty() && params.front().inited;
    });
    if (is_quant) {
      return false;
    }
  }

  if (mul_primitive_c->GetAttr(ops::kActivationType) != nullptr &&
      mul_primitive->get_activation_type() != ActivationType::NO_ACTIVATION) {
    MS_LOG(DEBUG) << "Only support mul node with no activation";
    return false;
  }
  if (cnode->size() != kInputSizeThree) {
    MS_LOG(DEBUG) << "Mul op is null or has an invalid input size";
    return false;
  }
  return true;
}

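// Squeezes leading and trailing dimensions of size 1 off the scale shape so it
// can match a tail slice of the Mul input shape. Every trailing 1 that is
// removed shifts the Scale axis left by one, which is tracked in axis_offset.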
bool MulAddFusion::AdjustScaleBiasTensorShape(size_t *axis_offset) const {
  MS_CHECK_TRUE_RET(axis_offset != nullptr, false);
  auto scale_shape = scale_tensor_->shape_c();
  if (mul_input_shape_ == scale_shape) {
    return true;
  }
  while (scale_shape.size() > DIMENSION_1D) {
    bool begin_with_value_one = scale_shape.at(FIRST_INPUT) == DIMENSION_1D;
    bool end_with_value_one = scale_shape.at(scale_shape.size() - DIMENSION_1D) == DIMENSION_1D;
    if (!begin_with_value_one && !end_with_value_one) {
      break;
    }
    if (begin_with_value_one) {
      scale_shape.erase(scale_shape.begin());
    }
    if (end_with_value_one) {
      scale_shape.erase(scale_shape.end() - DIMENSION_1D);
      *axis_offset += DIMENSION_1D;
    }
  }
  (void)scale_tensor_->set_shape(scale_shape);
  (void)bias_tensor_->set_shape(scale_shape);

  // set shape for abstract
  auto mul_abstract = mul_const_anode_->abstract();
  MS_CHECK_TRUE_RET(mul_abstract != nullptr, false);
  auto new_shape = std::make_shared<abstract::Shape>(scale_shape);
  MS_CHECK_TRUE_RET(new_shape != nullptr, false);
  mul_abstract->set_shape(new_shape);

  auto add_abstract = add_const_anode_->abstract();
  MS_CHECK_TRUE_RET(add_abstract != nullptr, false);
  auto new_add_shape = std::make_shared<abstract::Shape>(scale_shape);
  MS_CHECK_TRUE_RET(new_add_shape != nullptr, false);
  add_abstract->set_shape(new_add_shape);
  return true;
}

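// Scale is valid only when the scale and bias tensors share one shape and that
// shape matches the tail of the Mul input shape, shifted left by axis_offset.
// For example, a scale shape of {256} matches a Mul input of {1, 64, 64, 256}
// with axis_offset == 0.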
bool MulAddFusion::ScaleInputShapeValid(size_t *axis_offset) const {
  MS_ASSERT(scale_tensor_ != nullptr && bias_tensor_ != nullptr && axis_offset != nullptr);
  if (scale_tensor_->shape_c() != bias_tensor_->shape_c()) {
    return false;
  }
  // Remove dimensions of value 1 at the beginning or end of the shape vector.
  if (!AdjustScaleBiasTensorShape(axis_offset)) {
    MS_LOG(ERROR) << "Adjust scale shape and bias shape failed.";
    return false;
  }
  auto scale_shape = scale_tensor_->shape_c();
  if (mul_input_shape_.size() < scale_shape.size() || scale_shape.size() == 0) {
    return false;
  }
  size_t rank_diff = mul_input_shape_.size() - scale_shape.size();
  for (size_t i = 0; i < scale_shape.size(); ++i) {
    if (i + rank_diff < *axis_offset) {
      MS_LOG(ERROR) << "Subtraction overflow may cause an index out of range.";
      return false;
    }
    if (mul_input_shape_[i + rank_diff - *axis_offset] != scale_shape[i]) {
      return false;
    }
  }
  return true;
}

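// When the Mul input is produced by another op, its shape is reliable only if
// shape inference has already run on that op (marked by the kInferDone attr).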
bool MulAddFusion::MulInputAnodeIsInferred(const AnfNodePtr &mul_input_anode) const {
  auto mul_input_cnode = mul_input_anode->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(mul_input_cnode);
  auto prim = GetValueNode<PrimitivePtr>(mul_input_cnode->input(0));
  MS_CHECK_TRUE_RET(prim != nullptr, false);
  auto is_inferred = prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(prim->GetAttr(kInferDone));
  return is_inferred;
}

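// Propagates the source node's format attribute (e.g. NHWC/NCHW) to the new
// primitive so the fused Scale keeps the layout of the original Add.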
bool MulAddFusion::CopyNodeFormat(CNodePtr node, mindspore::ops::PrimitiveCPtr prim) const {
  auto src_prim = GetValueNode<PrimitiveCPtr>(node->input(0));
  MS_CHECK_TRUE_RET(src_prim != nullptr, false);
  if (src_prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
    auto value = src_prim->GetAttr(mindspore::ops::kFormat);
    MS_CHECK_TRUE_RET(value != nullptr, false);
    if (value->isa<mindspore::Int64Imm>()) {
      auto format = GetValue<int64_t>(value);
      prim->AddAttr(mindspore::ops::kFormat, MakeValue(format));
    }
  }
  return true;
}

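// Rewrites the matched subgraph
//   Add(Mul(x, scale), bias)  ==>  Scale(x, scale, bias)
// after validating both nodes and the constant shapes, carrying the Add's
// activation and format over to the new Scale node.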
AnfNodePtr MulAddFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
                                 const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto add_cnode = node->cast<CNodePtr>();
  if (!CheckAddNode(add_cnode)) {
    MS_LOG(DEBUG) << "Add op is not suitable for mul-add-fusion: " << node->fullname_with_scope();
    return nullptr;
  }

  auto mul_node = utils::isa<CNodePtr>(add_cnode->input(SECOND_INPUT)) ? add_cnode->input(SECOND_INPUT)
                                                                       : add_cnode->input(THIRD_INPUT);
  MS_CHECK_TRUE_RET(mul_node != nullptr, nullptr);
  auto mul_cnode = mul_node->cast<CNodePtr>();
  if (!CheckMulNode(func_graph, mul_cnode)) {
    MS_LOG(DEBUG) << "Mul op is not suitable for mul-add-fusion: " << mul_node->fullname_with_scope();
    return nullptr;
  }

  // Locate the variable input and the constant (scale) input of the Mul node.
  auto mul_input_anode = mul_cnode->input(SECOND_INPUT);
  if (utils::isa<ParameterPtr>(mul_input_anode)) {
    auto param_node = mul_input_anode->cast<ParameterPtr>();
    MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
    mul_const_anode_ = param_node->has_default() ? mul_input_anode : mul_cnode->input(THIRD_INPUT);
    mul_input_anode = param_node->has_default() ? mul_cnode->input(THIRD_INPUT) : mul_input_anode;
  } else if (utils::isa<CNodePtr>(mul_input_anode)) {
    mul_const_anode_ = mul_cnode->input(THIRD_INPUT);
  }
  size_t add_const_idx = utils::isa<CNodePtr>(add_cnode->input(SECOND_INPUT)) ? THIRD_INPUT : SECOND_INPUT;
  add_const_anode_ = add_cnode->input(add_const_idx);
  MS_CHECK_TRUE_RET(mul_const_anode_ != nullptr && add_const_anode_ != nullptr, nullptr);
  bias_tensor_ = GetTensorInfo(add_const_anode_);
  scale_tensor_ = GetTensorInfo(mul_const_anode_);
  MS_CHECK_TRUE_RET(bias_tensor_ != nullptr, nullptr);
  MS_CHECK_TRUE_RET(scale_tensor_ != nullptr, nullptr);
  MS_CHECK_TRUE_RET(mul_input_anode != nullptr, nullptr);
  if (mul_input_anode->isa<CNode>()) {
    if (!MulInputAnodeIsInferred(mul_input_anode)) {
      MS_LOG(DEBUG) << "mul_input_anode is not inferred, don't perform the ScaleInputShapeValid method.";
      return nullptr;
    }
  }
  if (FetchShapeFromAbstract(mul_input_anode->abstract(), &mul_input_shape_) != lite::RET_OK) {
    return nullptr;
  }
  // Scale requires the scale shape to be a tail sub-shape of the input shape,
  // and the scale shape to equal the bias shape.
  size_t axis_offset = 0;
  if (!ScaleInputShapeValid(&axis_offset)) {
    MS_LOG(DEBUG) << "Check input shape, scale shape and bias shape failed";
    return nullptr;
  }
  // create scale primitive
  auto scale_primitive = std::make_shared<ops::ScaleFusion>();
  if (scale_primitive == nullptr) {
    MS_LOG(ERROR) << "new scale primitive failed";
    return nullptr;
  }
  scale_primitive->set_activation_type(scale_act_type_);
  auto scale_primitive_c = scale_primitive->GetPrim();
  MS_CHECK_TRUE_RET(scale_primitive_c != nullptr, nullptr);
  if (INT_ADD_OVERFLOW_THRESHOLD(bias_tensor_->shape_c().size(), axis_offset, SIZE_MAX)) {
    MS_LOG(ERROR) << "Add overflow: " << bias_tensor_->shape_c().size() << " + " << axis_offset;
    return nullptr;
  }
  // A negative axis counts back from the end of the input shape: the scale is
  // applied to the last (bias rank + axis_offset) dimensions.
  scale_primitive->set_axis(-(static_cast<int64_t>(bias_tensor_->shape_c().size() + axis_offset)));

  // copy the format of add node to scale node
  if (!CopyNodeFormat(add_cnode, scale_primitive_c)) {
    MS_LOG(WARNING) << "Copy original node format failed";
  }

  // create scale op
  auto scale_node = func_graph->NewCNode(scale_primitive_c, {mul_input_anode, mul_const_anode_, add_const_anode_});
  MS_CHECK_TRUE_RET(scale_node != nullptr, nullptr);
  scale_node->set_abstract(add_cnode->abstract());
  return scale_node;
}
}  // namespace mindspore::opt