/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/mul_add_fusion.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "mindspore/core/ops/lite_ops.h"
#include "nnacl/op_base.h"
#include "ops/fusion/add_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/fusion/scale_fusion.h"
#include "ops/op_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore::opt {
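// Matches Add(Mul(...), const): the Mul output feeds the first input of Add, a constant the second.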
VectorRef MulAddFusion::DefineMulFirstPattern() const {
  auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
  MS_CHECK_TRUE_RET(is_const != nullptr, {});
  return VectorRef({is_add, is_mul, is_const});
}

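// Matches Add(const, Mul(...)): a constant feeds the first input of Add, the Mul output the second.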
VectorRef MulAddFusion::DefineMulSecondPattern() const {
  auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
  MS_CHECK_TRUE_RET(is_const != nullptr, {});
  auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  return VectorRef({is_add, is_const, is_mul});
}

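// Registers both operand orders, so the fusion fires regardless of which Add input the Mul feeds.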
std::unordered_map<std::string, VectorRef> MulAddFusion::DefinePatterns() const {
  std::unordered_map<std::string, VectorRef> patterns;
  patterns["MulFirstPatternName"] = DefineMulFirstPattern();
  patterns["MulSecondPatternName"] = DefineMulSecondPattern();
  return patterns;
}

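// An Add node is fusible only if it has exactly two data inputs, is not a training op,
// carries no initialized quantization parameters, and its fused activation (if any) is
// RELU or RELU6. The accepted activation type is recorded for the resulting Scale node.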
bool MulAddFusion::CheckAddNode(const mindspore::CNodePtr &cnode) const {
  MS_CHECK_TRUE_RET(cnode != nullptr, false);
  if (cnode->size() != kInputSizeThree) {
    MS_LOG(DEBUG) << "Add op is null or has a wrong input size";
    return false;
  }
  if (IsMarkedTrainOp(cnode)) {
    return false;
  }
  auto add_primitive = ops::GetOperator<ops::AddFusion>(cnode->input(0));
  MS_CHECK_TRUE_RET(add_primitive != nullptr, false);
  auto add_primitive_c = add_primitive->GetPrim();
  MS_CHECK_TRUE_RET(add_primitive_c != nullptr, false);
  auto quant_attr = add_primitive_c->GetAttr("quant_params");
  if (quant_attr != nullptr) {
    auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
    MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
    auto quant_params = quant_param_holder->get_input_quant_params();
    bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> &params) {
      return !params.empty() && params.front().inited;
    });
    if (is_quant) {
      return false;
    }
  }

  ActivationType add_act_type = ActivationType::NO_ACTIVATION;
  if (add_primitive_c->GetAttr(ops::kActivationType) != nullptr) {
    add_act_type = add_primitive->get_activation_type();
    if (add_act_type != ActivationType::RELU && add_act_type != ActivationType::RELU6 &&
        add_act_type != ActivationType::NO_ACTIVATION) {
      MS_LOG(DEBUG) << "Only support add node with relu or relu6 or no activation";
      return false;
    }
  }
  scale_act_type_ = add_act_type;
  return true;
}

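// A Mul node is fusible only if its output has a single consumer, it is not a training op,
// it carries no initialized quantization parameters, it has no fused activation, and it has
// exactly two data inputs.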
bool MulAddFusion::CheckMulNode(const mindspore::FuncGraphPtr &func_graph, const mindspore::CNodePtr &cnode) const {
  MS_ASSERT(func_graph != nullptr);
  MS_CHECK_TRUE_RET(cnode != nullptr, false);
  if (IsMultiOutputTensors(func_graph, cnode)) {
    MS_LOG(DEBUG) << "Mul op has multiple outputs";
    return false;
  }
  if (IsMarkedTrainOp(cnode)) {
    return false;
  }
  auto mul_primitive = ops::GetOperator<ops::MulFusion>(cnode->input(0));
  MS_CHECK_TRUE_RET(mul_primitive != nullptr, false);
  auto mul_primitive_c = mul_primitive->GetPrim();
  MS_CHECK_TRUE_RET(mul_primitive_c != nullptr, false);
  auto quant_attr = mul_primitive_c->GetAttr("quant_params");
  if (quant_attr != nullptr) {
    auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
    MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
    auto quant_params = quant_param_holder->get_input_quant_params();
    bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> &params) {
      return !params.empty() && params.front().inited;
    });
    if (is_quant) {
      return false;
    }
  }

  if (mul_primitive_c->GetAttr(ops::kActivationType) != nullptr &&
      mul_primitive->get_activation_type() != ActivationType::NO_ACTIVATION) {
    MS_LOG(DEBUG) << "Only support mul node with no activation";
    return false;
  }
  if (cnode->size() != kInputSizeThree) {
    MS_LOG(DEBUG) << "Mul op is null or has a wrong input size";
    return false;
  }
  return true;
}

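// Strips dimensions of size 1 from the front and back of the scale shape so that it can match
// a tail sub-shape of the Mul input. Each trailing dimension removed increments *axis_offset,
// which later shifts the Scale axis. The bias tensor and both const abstracts are given the
// same adjusted shape.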
bool MulAddFusion::AdjustScaleBiasTensorShape(size_t *axis_offset) const {
  MS_CHECK_TRUE_RET(axis_offset != nullptr, false);
  auto scale_shape = scale_tensor_->shape_c();
  if (mul_input_shape_ == scale_shape) {
    return true;
  }
  while (scale_shape.size() > DIMENSION_1D) {
    bool begin_with_value_one = scale_shape.at(FIRST_INPUT) == DIMENSION_1D;
    bool end_with_value_one = scale_shape.at(scale_shape.size() - DIMENSION_1D) == DIMENSION_1D;
    if (!begin_with_value_one && !end_with_value_one) {
      break;
    }
    if (begin_with_value_one) {
      scale_shape.erase(scale_shape.begin());
    }
    if (end_with_value_one) {
      scale_shape.erase(scale_shape.end() - DIMENSION_1D);
      *axis_offset += DIMENSION_1D;
    }
  }
  (void)scale_tensor_->set_shape(scale_shape);
  (void)bias_tensor_->set_shape(scale_shape);

  // update the abstracts so their shapes match the adjusted tensors
  auto mul_abstract = mul_const_anode_->abstract();
  MS_CHECK_TRUE_RET(mul_abstract != nullptr, false);
  auto new_shape = std::make_shared<abstract::Shape>(scale_shape);
  MS_CHECK_TRUE_RET(new_shape != nullptr, false);
  mul_abstract->set_shape(new_shape);

  auto add_abstract = add_const_anode_->abstract();
  MS_CHECK_TRUE_RET(add_abstract != nullptr, false);
  auto new_add_shape = std::make_shared<abstract::Shape>(scale_shape);
  MS_CHECK_TRUE_RET(new_add_shape != nullptr, false);
  add_abstract->set_shape(new_add_shape);
  return true;
}

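// Checks that scale and bias share the same shape and that, after trimming size-1 dimensions,
// the scale shape matches the tail of the Mul input shape shifted left by *axis_offset.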
bool MulAddFusion::ScaleInputShapeValid(size_t *axis_offset) const {
  MS_ASSERT(scale_tensor_ != nullptr && bias_tensor_ != nullptr && axis_offset != nullptr);
  if (scale_tensor_->shape_c() != bias_tensor_->shape_c()) {
    return false;
  }
  // remove dimensions of size 1 at the beginning or the end of the shape vector.
  if (!AdjustScaleBiasTensorShape(axis_offset)) {
    MS_LOG(ERROR) << "Adjust scale shape and bias shape failed.";
    return false;
  }
  auto scale_shape = scale_tensor_->shape_c();
  if (mul_input_shape_.size() < scale_shape.size() || scale_shape.size() == 0) {
    return false;
  }
  size_t rank_diff = mul_input_shape_.size() - scale_shape.size();
  for (size_t i = 0; i < scale_shape.size(); ++i) {
    if (i + rank_diff < *axis_offset) {
      MS_LOG(ERROR) << "Subtraction overflow may cause index out of range.";
      return false;
    }
    if (mul_input_shape_[i + rank_diff - *axis_offset] != scale_shape[i]) {
      return false;
    }
  }
  return true;
}

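// Returns true if shape inference has already completed for the CNode feeding the Mul,
// i.e. its primitive carries the kInferDone attribute set to true.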
bool MulAddFusion::MulInputAnodeIsInferred(const AnfNodePtr &mul_input_anode) const {
  auto mul_input_cnode = mul_input_anode->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(mul_input_cnode);
  auto prim = GetValueNode<PrimitivePtr>(mul_input_cnode->input(0));
  MS_CHECK_TRUE_RET(prim != nullptr, false);
  auto is_inferred = prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(prim->GetAttr(kInferDone));
  return is_inferred;
}

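// Copies the kFormat attribute (if present) from the source node's primitive to the new
// Scale primitive, so the fused node keeps the original data layout.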
bool MulAddFusion::CopyNodeFormat(CNodePtr node, mindspore::ops::PrimitiveCPtr prim) const {
  auto src_prim = GetValueNode<PrimitiveCPtr>(node->input(0));
  MS_CHECK_TRUE_RET(src_prim != nullptr, false);
  if (src_prim->GetAttr(mindspore::ops::kFormat) != nullptr) {
    auto value = src_prim->GetAttr(mindspore::ops::kFormat);
    MS_CHECK_TRUE_RET(value != nullptr, false);
    if (value->isa<mindspore::Int64Imm>()) {
      auto format = GetValue<int64_t>(value);
      prim->AddAttr(mindspore::ops::kFormat, MakeValue(format));
    }
  }
  return true;
}

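// Rewrites Add(Mul(x, scale), bias) into a single ScaleFusion(x, scale, bias) node when the
// pattern, quantization, activation and shape constraints checked above are all satisfied.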
AnfNodePtr MulAddFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
                                 const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto add_cnode = node->cast<CNodePtr>();
  if (!CheckAddNode(add_cnode)) {
    MS_LOG(DEBUG) << "Add op is not suitable for mul-add-fusion: " << node->fullname_with_scope();
    return nullptr;
  }

  auto mul_node = utils::isa<CNodePtr>(add_cnode->input(SECOND_INPUT)) ? add_cnode->input(SECOND_INPUT)
                                                                       : add_cnode->input(THIRD_INPUT);
  MS_CHECK_TRUE_RET(mul_node != nullptr, nullptr);
  auto mul_cnode = mul_node->cast<CNodePtr>();
  if (!CheckMulNode(func_graph, mul_cnode)) {
    MS_LOG(DEBUG) << "Mul op is not suitable for mul-add-fusion: " << mul_cnode->fullname_with_scope();
    return nullptr;
  }

  auto mul_input_anode = mul_cnode->input(SECOND_INPUT);
  if (utils::isa<ParameterPtr>(mul_input_anode)) {
    auto param_node = mul_input_anode->cast<ParameterPtr>();
    MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
    mul_const_anode_ = param_node->has_default() ? mul_input_anode : mul_cnode->input(THIRD_INPUT);
    mul_input_anode = param_node->has_default() ? mul_cnode->input(THIRD_INPUT) : mul_input_anode;
  } else if (utils::isa<CNodePtr>(mul_input_anode)) {
    mul_const_anode_ = mul_cnode->input(THIRD_INPUT);
  }
  size_t add_const_idx = utils::isa<CNodePtr>(add_cnode->input(SECOND_INPUT)) ? THIRD_INPUT : SECOND_INPUT;
  add_const_anode_ = add_cnode->input(add_const_idx);
  MS_CHECK_TRUE_RET(mul_const_anode_ != nullptr && add_const_anode_ != nullptr, nullptr);
  bias_tensor_ = GetTensorInfo(add_const_anode_);
  scale_tensor_ = GetTensorInfo(mul_const_anode_);
  MS_CHECK_TRUE_RET(bias_tensor_ != nullptr, nullptr);
  MS_CHECK_TRUE_RET(scale_tensor_ != nullptr, nullptr);
  MS_CHECK_TRUE_RET(mul_input_anode != nullptr, nullptr);
  if (mul_input_anode->isa<CNode>()) {
    if (!MulInputAnodeIsInferred(mul_input_anode)) {
      MS_LOG(DEBUG) << "mul_input_anode is not inferred, skip the ScaleInputShapeValid check.";
      return nullptr;
    }
  }
  if (FetchShapeFromAbstract(mul_input_anode->abstract(), &mul_input_shape_) != lite::RET_OK) {
    return nullptr;
  }
  // Scale requires its scale shape to be a tail sub-shape of the input shape, and the bias
  // shape to equal the scale shape.
  size_t axis_offset = 0;
  if (!ScaleInputShapeValid(&axis_offset)) {
    MS_LOG(DEBUG) << "Check input shape, scale shape and bias shape failed";
    return nullptr;
  }
  // create scale primitive
  auto scale_primitive = std::make_shared<ops::ScaleFusion>();
  if (scale_primitive == nullptr) {
    MS_LOG(ERROR) << "new scale primitive failed";
    return nullptr;
  }
  scale_primitive->set_activation_type(scale_act_type_);
  auto scale_primitive_c = scale_primitive->GetPrim();
  MS_CHECK_TRUE_RET(scale_primitive_c != nullptr, nullptr);
  if (INT_ADD_OVERFLOW_THRESHOLD(bias_tensor_->shape_c().size(), axis_offset, SIZE_MAX)) {
    MS_LOG(ERROR) << "Add overflow: " << bias_tensor_->shape_c().size() << " + " << axis_offset;
    return nullptr;
  }
  scale_primitive->set_axis(-(static_cast<int64_t>(bias_tensor_->shape_c().size() + axis_offset)));
  // copy the format of the add node to the scale node; warn on failure, as the success
  // of the fusion does not depend on it
  if (!CopyNodeFormat(add_cnode, scale_primitive_c)) {
    MS_LOG(WARNING) << "Copy original node format failed";
  }
  // create the scale op and let it inherit the Add node's abstract
  auto scale_node = func_graph->NewCNode(scale_primitive_c, {mul_input_anode, mul_const_anode_, add_const_anode_});
  MS_CHECK_TRUE_RET(scale_node != nullptr, nullptr);
  scale_node->set_abstract(add_cnode->abstract());
  return scale_node;
}
}  // namespace mindspore::opt