1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define USE_DEPRECATED_API
17 #include "tools/optimizer/fusion/mul_reduce_fusion.h"
#include <functional>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>
24 #include "mindspore/core/ops/lite_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "tools/optimizer/common/gllo_utils.h"
27 #include "tools/lite_exporter/fetch_content.h"
28 #include "ops/fusion/mat_mul_fusion.h"
29 #include "ops/fusion/mul_fusion.h"
30 #include "ops/squeeze.h"
31 #include "ops/op_name.h"
32 #include "nnacl/op_base.h"
33
34 namespace mindspore {
35 namespace opt {
namespace {
// The only reduce axes this fusion supports, expressed as negative offsets from
// the end of the shape: -1 is the last dimension, -2 the second-to-last.
// CheckAxisCond normalizes any positive axis into this form before comparing.
constexpr int kReciprocalFirstIndex = -1;
constexpr int kReciprocalSecondIndex = -2;
}  // namespace
40
Run(const FuncGraphPtr & func_graph)41 bool MulReduceFusion::Run(const FuncGraphPtr &func_graph) {
42 if (func_graph == nullptr) {
43 return false;
44 }
45 auto ret = preprocessor_.Run(func_graph);
46 if (ret == lite::RET_NOT_SUPPORT) {
47 return true;
48 }
49 if (ret != lite::RET_OK) {
50 return false;
51 }
52 auto &shape_container = preprocessor_.GetShapeContainer();
53 std::vector<CNodePtr> reduce_ops;
54 for (auto infos : shape_container) {
55 if (!utils::isa<CNode>(infos.first)) {
56 continue;
57 }
58 if (!CheckPrimitiveType(infos.first, prim::kPrimReduceFusion)) {
59 continue;
60 }
61 reduce_ops.push_back(infos.first->cast<CNodePtr>());
62 }
63 for (auto reduce_op : reduce_ops) {
64 ret = ProcessOp(func_graph, reduce_op);
65 if (ret != lite::RET_OK) {
66 MS_LOG(ERROR) << "mul-reduce fusion process failed.";
67 return false;
68 }
69 }
70 ret = PostProcess(func_graph);
71 if (ret != lite::RET_OK) {
72 MS_LOG(ERROR) << "mul-reduce fusion post-process failed.";
73 return false;
74 }
75 return true;
76 }
77
ProcessOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)78 int MulReduceFusion::ProcessOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
79 auto is_meet_cond = CheckBasicCond(func_graph, cnode);
80 if (!is_meet_cond) {
81 return lite::RET_OK;
82 }
83 bool need_post_mul = false;
84 if (reduce_mode_ == ReduceMode::Reduce_Mean) {
85 auto ret = ProcessGather();
86 if (ret == lite::RET_NOT_SUPPORT) {
87 need_post_mul = true;
88 } else if (ret != lite::RET_OK) {
89 MS_LOG(ERROR) << "Process Gather op failed.";
90 return lite::RET_ERROR;
91 }
92 }
93 if (!keep_dim_) {
94 auto ret = GenerateSqueeze(func_graph, cnode);
95 if (ret != lite::RET_OK) {
96 return lite::RET_ERROR;
97 }
98 }
99 if (need_post_mul) {
100 auto ret = GenerateMul(func_graph, cnode);
101 if (ret != lite::RET_OK) {
102 return lite::RET_ERROR;
103 }
104 }
105 auto ret = GenerateMatmul(func_graph, cnode);
106 if (ret != lite::RET_OK) {
107 return lite::RET_ERROR;
108 }
109 return lite::RET_OK;
110 }
111
ProcessGather()112 int MulReduceFusion::ProcessGather() {
113 MS_ASSERT(gather_.size() > C1NUM);
114 auto gather_table = gather_->input(1);
115 if (gather_table == nullptr || utils::isa<CNode>(gather_table)) {
116 return lite::RET_NOT_SUPPORT;
117 }
118 lite::DataInfo data_info;
119 auto ret = lite::FetchConstData(gather_, 1, converter::kFmkTypeMs, &data_info, false);
120 MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Fetch const data of gather failed.");
121 if (data_info.data_type_ != kNumberTypeFloat && data_info.data_type_ != kNumberTypeFloat32) {
122 return lite::RET_NOT_SUPPORT;
123 }
124 if (data_info.data_ptr_ == nullptr) {
125 return lite::RET_NOT_SUPPORT;
126 }
127 auto *float_data = static_cast<float *>(data_info.data_ptr_);
128 auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
129 for (int64_t i = 0; i < element_num; ++i) {
130 float_data[i] *= coeff_;
131 }
132 return lite::RET_OK;
133 }
134
PostProcess(const FuncGraphPtr & func_graph)135 int MulReduceFusion::PostProcess(const FuncGraphPtr &func_graph) {
136 MS_ASSERT(func_graph != nullptr);
137 if (squeeze_infos_.empty()) {
138 return lite::RET_OK;
139 }
140 std::set<CNodePtr> concat_ops;
141 auto manager = func_graph->manager();
142 MS_ASSERT(manager != nullptr);
143 auto &node_users = manager->node_users();
144 for (auto &squeeze : squeeze_infos_) {
145 auto &node_user = node_users[squeeze.first];
146 for (auto &user : node_user) {
147 auto node = user.first;
148 if (!utils::isa<CNode>(node)) {
149 continue;
150 }
151 auto cnode = node->cast<CNodePtr>();
152 if (CheckPrimitiveType(cnode, prim::kPrimConcat)) {
153 (void)concat_ops.insert(cnode);
154 }
155 }
156 }
157 for (auto &concat : concat_ops) {
158 auto ret = PostProcessSqueezeWithConcat(func_graph, concat);
159 if (ret != lite::RET_OK) {
160 MS_LOG(ERROR) << "mul-reduce-fusion's PostProcess failed.";
161 return lite::RET_ERROR;
162 }
163 }
164 return lite::RET_OK;
165 }
166
PostProcessSqueezeWithConcat(const FuncGraphPtr & func_graph,const CNodePtr & cnode)167 int MulReduceFusion::PostProcessSqueezeWithConcat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
168 MS_ASSERT(func_graph != nullptr);
169 MS_ASSERT(cnode != nullptr);
170 if (!CheckConcatOp(func_graph, cnode)) {
171 return lite::RET_OK;
172 }
173 auto manager = func_graph->manager();
174 MS_ASSERT(manager != nullptr);
175 for (int i = 1; i < static_cast<int>(cnode->size()); ++i) {
176 manager->SetEdge(cnode, i, cnode->input(i)->cast<CNodePtr>()->input(1));
177 }
178 auto concat_prim = GetCNodePrimitive(cnode);
179 MS_ASSERT(concat_prim != nullptr);
180 (void)concat_prim->AddAttr(ops::kAxis, MakeValue<int64_t>(concat_axis_));
181 auto &node_users = manager->node_users();
182 auto concat_users = node_users[cnode];
183 CNodePtr post_squeeze{nullptr};
184 for (auto &user : concat_users) {
185 if (CheckPrimitiveType(user.first, prim::kPrimReshape)) {
186 continue;
187 }
188 if (post_squeeze == nullptr) {
189 auto squeeze = std::make_shared<ops::Squeeze>();
190 MS_CHECK_TRUE_MSG(squeeze != nullptr, lite::RET_ERROR, "Squeeze create failed.");
191 squeeze->set_axis(std::vector<int64_t>{axis_});
192 auto squeeze_prim = squeeze->GetPrim();
193 MS_CHECK_TRUE_MSG(squeeze_prim != nullptr, lite::RET_ERROR, "Squeeze create failed.");
194 post_squeeze = func_graph->NewCNode(squeeze_prim, {cnode});
195 MS_CHECK_TRUE_MSG(post_squeeze != nullptr, lite::RET_ERROR, "Squeeze-cnode create failed.");
196 post_squeeze->set_fullname_with_scope(cnode->fullname_with_scope() + "/Squeeze");
197 }
198 manager->SetEdge(user.first, user.second, post_squeeze);
199 }
200 return lite::RET_OK;
201 }
202
GenerateMatmul(const FuncGraphPtr & func_graph,const CNodePtr & cnode)203 int MulReduceFusion::GenerateMatmul(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
204 MS_ASSERT(func_graph != nullptr);
205 MS_ASSERT(cnode != nullptr);
206 auto manager = func_graph->manager();
207 MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is a nullptr.");
208 auto mul_op = cnode->input(1)->cast<CNodePtr>(); // which has been checked before.
209 if (exchange_) {
210 manager->SetEdge(cnode, 1, mul_op->input(kInputIndexTwo));
211 manager->SetEdge(cnode, kInputIndexTwo, mul_op->input(1));
212 } else {
213 manager->SetEdge(cnode, 1, mul_op->input(1));
214 manager->SetEdge(cnode, kInputIndexTwo, mul_op->input(kInputIndexTwo));
215 }
216 auto matmul_prim = std::make_shared<ops::MatMulFusion>();
217 MS_CHECK_TRUE_MSG(matmul_prim != nullptr, lite::RET_ERROR, "Matmul create failed.");
218 auto matmul_prim_c = matmul_prim->GetPrim();
219 MS_CHECK_TRUE_MSG(matmul_prim_c != nullptr, lite::RET_ERROR, "Matmul create failed.");
220 matmul_prim->set_transpose_a(transpose_a_);
221 matmul_prim->set_transpose_b(transpose_b_);
222 MS_ASSERT(cnode->input(0) != nullptr);
223 auto reduce_prim_carrier = cnode->input(0)->cast<ValueNodePtr>();
224 MS_ASSERT(reduce_prim_carrier != nullptr);
225 reduce_prim_carrier->set_value(matmul_prim_c);
226 return lite::RET_OK;
227 }
228
GenerateSqueeze(const FuncGraphPtr & func_graph,const CNodePtr & cnode)229 int MulReduceFusion::GenerateSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
230 MS_ASSERT(func_graph != nullptr);
231 MS_ASSERT(cnode != nullptr);
232 auto manager = func_graph->manager();
233 MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is a nullptr.");
234 auto squeeze = std::make_shared<ops::Squeeze>();
235 MS_CHECK_TRUE_MSG(squeeze != nullptr, lite::RET_ERROR, "Squeeze create failed.");
236 squeeze->set_axis(std::vector<int64_t>{axis_});
237 auto squeeze_prim = squeeze->GetPrim();
238 MS_CHECK_TRUE_MSG(squeeze_prim != nullptr, lite::RET_ERROR, "Squeeze create failed.");
239 auto squeeze_cnode = func_graph->NewCNode(squeeze_prim, {cnode});
240 MS_CHECK_TRUE_MSG(squeeze_cnode != nullptr, lite::RET_ERROR, "Squeeze-cnode create failed.");
241 auto mul_op = cnode->input(1);
242 MS_ASSERT(mul_op != nullptr);
243 squeeze_cnode->set_fullname_with_scope(mul_op->fullname_with_scope() + "/Squeeze");
244 auto success = manager->Replace(cnode, squeeze_cnode);
245 MS_CHECK_TRUE_MSG(success, lite::RET_ERROR, "Replace old node failed.");
246 auto &shape_infos = preprocessor_.GetShapeContainer();
247 MS_ASSERT(shape_infos.find(mul_op) != shape_infos.end());
248 auto &out_shape_infos = shape_infos.at(mul_op).second;
249 MS_ASSERT(!out_shape_infos.empty());
250 squeeze_infos_[squeeze_cnode] = std::make_pair(axis_, out_shape_infos.front().size() - 1);
251 return lite::RET_OK;
252 }
253
GenerateMul(const FuncGraphPtr & func_graph,const CNodePtr & cnode)254 int MulReduceFusion::GenerateMul(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
255 MS_ASSERT(func_graph != nullptr);
256 MS_ASSERT(cnode != nullptr);
257 if (coeff_ == 1.0f) {
258 return lite::RET_OK;
259 }
260 auto manager = func_graph->manager();
261 MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is a nullptr.");
262 auto mul = std::make_shared<ops::MulFusion>();
263 MS_CHECK_TRUE_MSG(mul != nullptr, lite::RET_ERROR, "Mul create failed.");
264 auto mul_prim = mul->GetPrim();
265 MS_CHECK_TRUE_MSG(mul_prim != nullptr, lite::RET_ERROR, "Mul create failed.");
266 auto old_mul_op = cnode->input(1);
267 MS_ASSERT(old_mul_op != nullptr);
268 auto second_input_node =
269 BuildFloatValueParameterNode(func_graph, coeff_, old_mul_op->fullname_with_scope() + "/scale");
270 MS_CHECK_TRUE_MSG(second_input_node != nullptr, lite::RET_ERROR, "Mul second-input create failed.");
271 auto mul_cnode = func_graph->NewCNode(mul_prim, {cnode, second_input_node});
272 MS_CHECK_TRUE_MSG(mul_cnode != nullptr, lite::RET_ERROR, "Mul-cnode create failed.");
273 mul_cnode->set_fullname_with_scope(old_mul_op->fullname_with_scope());
274 auto success = manager->Replace(cnode, mul_cnode);
275 MS_CHECK_TRUE_MSG(success, lite::RET_ERROR, "Replace old node failed.");
276 return lite::RET_OK;
277 }
278
CheckBasicCond(const FuncGraphPtr & func_graph,const CNodePtr & cnode)279 bool MulReduceFusion::CheckBasicCond(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
280 MS_ASSERT(cnode != nullptr);
281 if (cnode->size() < kInputSizeThree) {
282 return false;
283 }
284 if (IsMarkedTrainOp(cnode)) {
285 return false;
286 }
287 auto prim = GetCNodePrimitive(cnode);
288 MS_ASSERT(prim != nullptr);
289 bool is_to_end = prim->GetAttr(ops::kReduceToEnd) != nullptr && GetValue<bool>(prim->GetAttr(ops::kReduceToEnd));
290 if (is_to_end) {
291 return false;
292 }
293 keep_dim_ = prim->GetAttr(ops::kKeepDims) != nullptr && GetValue<bool>(prim->GetAttr(ops::kKeepDims));
294 auto mode_attr = prim->GetAttr(ops::kMode);
295 if (mode_attr == nullptr) {
296 return false;
297 }
298 reduce_mode_ = static_cast<int>(GetValue<int64_t>(mode_attr));
299 if (reduce_mode_ != ReduceMode::Reduce_Sum && reduce_mode_ != ReduceMode::Reduce_Mean) {
300 return false;
301 }
302 auto first_input = cnode->input(1);
303 if (!utils::isa<CNode>(first_input)) {
304 return false;
305 }
306 if (!CheckPrimitiveType(first_input, prim::kPrimMulFusion)) {
307 return false;
308 }
309 if (IsMarkedTrainOp(first_input->cast<CNodePtr>())) {
310 return false;
311 }
312 auto mul_prim = GetCNodePrimitive(first_input);
313 MS_ASSERT(mul_prim != nullptr);
314 auto act_type = mul_prim->GetAttr(ops::kActivationType) == nullptr
315 ? ActivationType::NO_ACTIVATION
316 : GetValue<int64_t>(mul_prim->GetAttr(ops::kActivationType));
317 if (act_type != ActivationType::NO_ACTIVATION) {
318 return false;
319 }
320 if (IsMultiOutputTensors(func_graph, first_input)) {
321 return false;
322 }
323 bool is_axis_meet = CheckAxisCond(cnode);
324 if (!is_axis_meet) {
325 return false;
326 }
327 bool is_shape_meet = CheckShapeCond(cnode);
328 if (!is_shape_meet) {
329 return false;
330 }
331 return CheckGatherOp(func_graph, cnode);
332 }
333
CheckAxisCond(const CNodePtr & cnode)334 bool MulReduceFusion::CheckAxisCond(const CNodePtr &cnode) {
335 MS_ASSERT(cnode != nullptr);
336 auto &shape_container = preprocessor_.GetShapeContainer();
337 auto first_input = cnode->input(1);
338 if (shape_container.find(first_input) == shape_container.end()) {
339 return false;
340 }
341 if (shape_container.at(first_input).second.empty()) {
342 return false;
343 }
344 auto in_shape = shape_container.at(first_input).second.front();
345 auto second_input = cnode->input(kInputIndexTwo);
346 if (second_input == nullptr || utils::isa<CNode>(second_input)) {
347 return false;
348 }
349 lite::DataInfo data_info;
350 auto ret = lite::FetchConstData(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, false);
351 MS_CHECK_TRUE_MSG(ret == lite::RET_OK, false, "Fetch reduceOp's axis failed.");
352 auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
353 if (data_info.data_ptr_ == nullptr || element_num != 1) {
354 return false;
355 }
356 if (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32) {
357 axis_ = *(static_cast<int *>(data_info.data_ptr_));
358 } else if (data_info.data_type_ == kNumberTypeInt64) {
359 axis_ = static_cast<int>(*(static_cast<int64_t *>(data_info.data_ptr_)));
360 } else {
361 return false;
362 }
363 if (axis_ > 0) {
364 axis_ -= static_cast<int>(in_shape.size());
365 }
366 if (axis_ != kReciprocalFirstIndex && axis_ != kReciprocalSecondIndex) {
367 return false;
368 }
369 return true;
370 }
371
CheckShapeCond(const CNodePtr & cnode)372 bool MulReduceFusion::CheckShapeCond(const CNodePtr &cnode) {
373 MS_ASSERT(cnode != nullptr);
374 auto &shape_container = preprocessor_.GetShapeContainer();
375 auto first_input = cnode->input(1);
376 if (shape_container.find(first_input) == shape_container.end()) {
377 return false;
378 }
379 if (shape_container.at(first_input).first.size() != kInputSizeTwo) {
380 return false;
381 }
382 auto mul_in0_shape = shape_container.at(first_input).first.front();
383 auto mul_in1_shape = shape_container.at(first_input).first.back();
384 if (mul_in0_shape.size() < kInputSizeTwo || mul_in1_shape.size() < kInputSizeTwo) {
385 return false;
386 }
387 if (mul_in0_shape.back() <= 0 || mul_in0_shape[mul_in0_shape.size() - C2NUM] <= 0 || mul_in1_shape.back() <= 0 ||
388 mul_in1_shape[mul_in1_shape.size() - C2NUM] <= 0) {
389 return false;
390 }
391 if (axis_ == kReciprocalFirstIndex) {
392 if (mul_in0_shape.back() != mul_in1_shape.back() ||
393 (mul_in0_shape[mul_in0_shape.size() - C2NUM] != 1 && mul_in1_shape[mul_in1_shape.size() - C2NUM] != 1)) {
394 return false;
395 }
396 exchange_ = mul_in1_shape[mul_in1_shape.size() - C2NUM] != 1;
397 transpose_a_ = false;
398 transpose_b_ = true;
399 MS_ASSERT(mul_in0_shape.back() != 0);
400 coeff_ = 1.0f / static_cast<float>(mul_in0_shape.back());
401 return true;
402 }
403 if (axis_ == kReciprocalSecondIndex) {
404 if (mul_in0_shape[mul_in0_shape.size() - C2NUM] != mul_in1_shape[mul_in1_shape.size() - C2NUM] ||
405 (mul_in0_shape.back() != 1 && mul_in1_shape.back() != 1)) {
406 return false;
407 }
408 exchange_ = mul_in0_shape.back() != 1;
409 transpose_a_ = true;
410 transpose_b_ = false;
411 MS_ASSERT(mul_in0_shape[mul_in0_shape.size() - C2NUM] != 0);
412 coeff_ = 1.0f / static_cast<float>(mul_in0_shape[mul_in0_shape.size() - C2NUM]);
413 return true;
414 }
415 return false;
416 }
417
CheckGatherOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)418 bool MulReduceFusion::CheckGatherOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
419 MS_ASSERT(cnode != nullptr);
420 if (reduce_mode_ == ReduceMode::Reduce_Sum) {
421 return true;
422 }
423 if (reduce_mode_ != ReduceMode::Reduce_Mean) {
424 return false;
425 }
426 auto mul_op = cnode->input(1);
427 if (!utils::isa<CNode>(mul_op)) {
428 return false;
429 }
430 auto mul_op_cnode = mul_op->cast<CNodePtr>();
431 for (size_t i = 1; i < mul_op_cnode->size(); ++i) {
432 if (!utils::isa<CNode>(mul_op_cnode->input(i))) {
433 continue;
434 }
435 if (CheckPrimitiveType(mul_op_cnode->input(i), prim::kPrimGather)) {
436 gather_ = mul_op_cnode->input(i)->cast<CNodePtr>();
437 break;
438 }
439 }
440 if (gather_ == nullptr) {
441 return false;
442 }
443 if (IsMarkedTrainOp(gather_)) {
444 return false;
445 }
446 if (IsMultiOutputTensors(func_graph, gather_)) {
447 return false;
448 }
449 return true;
450 }
451
// Checks whether `cnode` (a Concat) is eligible for the squeeze-hoisting
// rewrite: every input must be a Squeeze generated by this pass, all squeezes
// must share the same removed axis and output rank, and each must feed only the
// concat. On success, rebases the concat axis onto the pre-squeeze layout and
// stores it in concat_axis_ (read by PostProcessSqueezeWithConcat).
bool MulReduceFusion::CheckConcatOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  int axis{0};      // axis each input Squeeze removed (from squeeze_infos_)
  int out_dims{0};  // rank of each Squeeze's output
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto in_node = cnode->input(i);
    if (!utils::isa<CNode>(in_node)) {
      return false;
    }
    auto in_cnode = in_node->cast<CNodePtr>();
    // Only squeezes created by GenerateSqueeze are known to be removable.
    if (squeeze_infos_.find(in_cnode) == squeeze_infos_.end()) {
      return false;
    }
    if (IsMultiOutputTensors(func_graph, in_node)) {
      return false;  // squeeze has other consumers: cannot be bypassed
    }
    if (i == 1) {
      axis = squeeze_infos_[in_cnode].first;
      out_dims = squeeze_infos_[in_cnode].second;
    } else {
      // All inputs must agree on the squeezed axis and resulting rank.
      if (squeeze_infos_[in_cnode].first != axis || squeeze_infos_[in_cnode].second != out_dims) {
        return false;
      }
    }
  }
  auto concat_prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_RET(concat_prim != nullptr, false);
  concat_axis_ = concat_prim->GetAttr(ops::kAxis) == nullptr
                   ? 0
                   : static_cast<int>(GetValue<int64_t>(concat_prim->GetAttr(ops::kAxis)));
  // Normalize the squeezed axis to a non-negative index in the PRE-squeeze
  // layout, which has out_dims + 1 dimensions.
  axis = axis < 0 ? axis + out_dims + 1 : axis;
  MS_CHECK_TRUE_RET(axis >= 0 && axis <= out_dims, false);
  // Normalize the concat axis within the post-squeeze (out_dims-rank) layout.
  concat_axis_ = concat_axis_ < 0 ? concat_axis_ + out_dims : concat_axis_;
  MS_CHECK_TRUE_RET(concat_axis_ >= 0 && concat_axis_ < out_dims, false);
  // Shift the concat axis past the re-inserted dimension when it lies at or
  // after the squeezed axis in the pre-squeeze layout.
  if (concat_axis_ >= axis) {
    ++concat_axis_;
  }
  return true;
}
491 } // namespace opt
492 } // namespace mindspore
493