• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/mul_reduce_fusion.h"
#include <functional>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "ops/fusion/mat_mul_fusion.h"
#include "ops/fusion/mul_fusion.h"
#include "ops/squeeze.h"
#include "ops/op_name.h"
#include "nnacl/op_base.h"
33 
34 namespace mindspore {
35 namespace opt {
namespace {
// The fusion only handles reductions over the last (-1) or second-to-last (-2)
// axis; CheckAxisCond normalizes the reduce axis to one of these negative values.
constexpr int kReciprocalFirstIndex = -1;
constexpr int kReciprocalSecondIndex = -2;
}  // namespace
40 
Run(const FuncGraphPtr & func_graph)41 bool MulReduceFusion::Run(const FuncGraphPtr &func_graph) {
42   if (func_graph == nullptr) {
43     return false;
44   }
45   auto ret = preprocessor_.Run(func_graph);
46   if (ret == lite::RET_NOT_SUPPORT) {
47     return true;
48   }
49   if (ret != lite::RET_OK) {
50     return false;
51   }
52   auto &shape_container = preprocessor_.GetShapeContainer();
53   std::vector<CNodePtr> reduce_ops;
54   for (auto infos : shape_container) {
55     if (!utils::isa<CNode>(infos.first)) {
56       continue;
57     }
58     if (!CheckPrimitiveType(infos.first, prim::kPrimReduceFusion)) {
59       continue;
60     }
61     reduce_ops.push_back(infos.first->cast<CNodePtr>());
62   }
63   for (auto reduce_op : reduce_ops) {
64     ret = ProcessOp(func_graph, reduce_op);
65     if (ret != lite::RET_OK) {
66       MS_LOG(ERROR) << "mul-reduce fusion process failed.";
67       return false;
68     }
69   }
70   ret = PostProcess(func_graph);
71   if (ret != lite::RET_OK) {
72     MS_LOG(ERROR) << "mul-reduce fusion post-process failed.";
73     return false;
74   }
75   return true;
76 }
77 
int MulReduceFusion::ProcessOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // Fuse one Mul + Reduce(Sum|Mean) pair into a MatMul. CheckBasicCond both
  // validates the pattern and fills the member state consumed below
  // (reduce_mode_, keep_dim_, axis_, coeff_, exchange_, transpose flags, gather_).
  auto is_meet_cond = CheckBasicCond(func_graph, cnode);
  if (!is_meet_cond) {
    return lite::RET_OK;  // Pattern not applicable: leave the graph untouched.
  }
  bool need_post_mul = false;
  if (reduce_mode_ == ReduceMode::Reduce_Mean) {
    // For ReduceMean, try to fold the 1/N coefficient into the Gather's
    // constant table; if that is not possible (RET_NOT_SUPPORT), a scalar Mul
    // is appended after the fused MatMul instead.
    auto ret = ProcessGather();
    if (ret == lite::RET_NOT_SUPPORT) {
      need_post_mul = true;
    } else if (ret != lite::RET_OK) {
      MS_LOG(ERROR) << "Process Gather op failed.";
      return lite::RET_ERROR;
    }
  }
  if (!keep_dim_) {
    // The Reduce dropped the reduced axis, but MatMul keeps it: insert a
    // Squeeze after the node to restore the expected output rank.
    auto ret = GenerateSqueeze(func_graph, cnode);
    if (ret != lite::RET_OK) {
      return lite::RET_ERROR;
    }
  }
  if (need_post_mul) {
    // Append the 1/N scaling that could not be folded into the Gather table.
    auto ret = GenerateMul(func_graph, cnode);
    if (ret != lite::RET_OK) {
      return lite::RET_ERROR;
    }
  }
  // Finally rewrite the Reduce node itself into a MatMul in place.
  auto ret = GenerateMatmul(func_graph, cnode);
  if (ret != lite::RET_OK) {
    return lite::RET_ERROR;
  }
  return lite::RET_OK;
}
111 
ProcessGather()112 int MulReduceFusion::ProcessGather() {
113   MS_ASSERT(gather_.size() > C1NUM);
114   auto gather_table = gather_->input(1);
115   if (gather_table == nullptr || utils::isa<CNode>(gather_table)) {
116     return lite::RET_NOT_SUPPORT;
117   }
118   lite::DataInfo data_info;
119   auto ret = lite::FetchConstData(gather_, 1, converter::kFmkTypeMs, &data_info, false);
120   MS_CHECK_TRUE_MSG(ret == lite::RET_OK, lite::RET_ERROR, "Fetch const data of gather failed.");
121   if (data_info.data_type_ != kNumberTypeFloat && data_info.data_type_ != kNumberTypeFloat32) {
122     return lite::RET_NOT_SUPPORT;
123   }
124   if (data_info.data_ptr_ == nullptr) {
125     return lite::RET_NOT_SUPPORT;
126   }
127   auto *float_data = static_cast<float *>(data_info.data_ptr_);
128   auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
129   for (int64_t i = 0; i < element_num; ++i) {
130     float_data[i] *= coeff_;
131   }
132   return lite::RET_OK;
133 }
134 
PostProcess(const FuncGraphPtr & func_graph)135 int MulReduceFusion::PostProcess(const FuncGraphPtr &func_graph) {
136   MS_ASSERT(func_graph != nullptr);
137   if (squeeze_infos_.empty()) {
138     return lite::RET_OK;
139   }
140   std::set<CNodePtr> concat_ops;
141   auto manager = func_graph->manager();
142   MS_ASSERT(manager != nullptr);
143   auto &node_users = manager->node_users();
144   for (auto &squeeze : squeeze_infos_) {
145     auto &node_user = node_users[squeeze.first];
146     for (auto &user : node_user) {
147       auto node = user.first;
148       if (!utils::isa<CNode>(node)) {
149         continue;
150       }
151       auto cnode = node->cast<CNodePtr>();
152       if (CheckPrimitiveType(cnode, prim::kPrimConcat)) {
153         (void)concat_ops.insert(cnode);
154       }
155     }
156   }
157   for (auto &concat : concat_ops) {
158     auto ret = PostProcessSqueezeWithConcat(func_graph, concat);
159     if (ret != lite::RET_OK) {
160       MS_LOG(ERROR) << "mul-reduce-fusion's PostProcess failed.";
161       return lite::RET_ERROR;
162     }
163   }
164   return lite::RET_OK;
165 }
166 
int MulReduceFusion::PostProcessSqueezeWithConcat(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // Merge the per-branch Squeeze nodes feeding this Concat into one Squeeze
  // placed after the Concat:  Concat(Squeeze(a), Squeeze(b), ...)  becomes
  // Squeeze(Concat(a, b, ...)) with the concat axis remapped accordingly.
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(cnode != nullptr);
  // CheckConcatOp validates that every input is one of our squeezes with a
  // common axis/rank, and computes the remapped concat_axis_.
  if (!CheckConcatOp(func_graph, cnode)) {
    return lite::RET_OK;
  }
  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  // Bypass each input Squeeze: wire the Concat directly to the Squeeze's input.
  for (int i = 1; i < static_cast<int>(cnode->size()); ++i) {
    manager->SetEdge(cnode, i, cnode->input(i)->cast<CNodePtr>()->input(1));
  }
  auto concat_prim = GetCNodePrimitive(cnode);
  MS_ASSERT(concat_prim != nullptr);
  // The Concat now operates in the un-squeezed coordinate system.
  (void)concat_prim->AddAttr(ops::kAxis, MakeValue<int64_t>(concat_axis_));
  auto &node_users = manager->node_users();
  // Copy the user list: SetEdge below mutates node_users while we iterate.
  auto concat_users = node_users[cnode];
  CNodePtr post_squeeze{nullptr};
  for (auto &user : concat_users) {
    // Reshape consumers define the rank themselves, so they can take the raw
    // (un-squeezed) Concat output directly.
    if (CheckPrimitiveType(user.first, prim::kPrimReshape)) {
      continue;
    }
    // Lazily create a single shared Squeeze that restores the original rank
    // for all remaining consumers.
    if (post_squeeze == nullptr) {
      auto squeeze = std::make_shared<ops::Squeeze>();
      MS_CHECK_TRUE_MSG(squeeze != nullptr, lite::RET_ERROR, "Squeeze create failed.");
      squeeze->set_axis(std::vector<int64_t>{axis_});
      auto squeeze_prim = squeeze->GetPrim();
      MS_CHECK_TRUE_MSG(squeeze_prim != nullptr, lite::RET_ERROR, "Squeeze create failed.");
      post_squeeze = func_graph->NewCNode(squeeze_prim, {cnode});
      MS_CHECK_TRUE_MSG(post_squeeze != nullptr, lite::RET_ERROR, "Squeeze-cnode create failed.");
      post_squeeze->set_fullname_with_scope(cnode->fullname_with_scope() + "/Squeeze");
    }
    manager->SetEdge(user.first, user.second, post_squeeze);
  }
  return lite::RET_OK;
}
202 
int MulReduceFusion::GenerateMatmul(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // Rewrite the Reduce node into a MatMulFusion in place:
  //  - its two data inputs become the Mul's operands (swapped when exchange_
  //    is set, as decided by CheckShapeCond),
  //  - its primitive value-node is overwritten with a MatMul primitive carrying
  //    the transpose flags computed in CheckShapeCond.
  // Rewriting in place means none of the node's users need re-wiring.
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(cnode != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is a nullptr.");
  auto mul_op = cnode->input(1)->cast<CNodePtr>();  // which has been checked before.
  if (exchange_) {
    manager->SetEdge(cnode, 1, mul_op->input(kInputIndexTwo));
    manager->SetEdge(cnode, kInputIndexTwo, mul_op->input(1));
  } else {
    manager->SetEdge(cnode, 1, mul_op->input(1));
    manager->SetEdge(cnode, kInputIndexTwo, mul_op->input(kInputIndexTwo));
  }
  auto matmul_prim = std::make_shared<ops::MatMulFusion>();
  MS_CHECK_TRUE_MSG(matmul_prim != nullptr, lite::RET_ERROR, "Matmul create failed.");
  auto matmul_prim_c = matmul_prim->GetPrim();
  MS_CHECK_TRUE_MSG(matmul_prim_c != nullptr, lite::RET_ERROR, "Matmul create failed.");
  matmul_prim->set_transpose_a(transpose_a_);
  matmul_prim->set_transpose_b(transpose_b_);
  MS_ASSERT(cnode->input(0) != nullptr);
  // Swap the primitive held by the node's first input (its value node),
  // turning the former Reduce into a MatMul.
  auto reduce_prim_carrier = cnode->input(0)->cast<ValueNodePtr>();
  MS_ASSERT(reduce_prim_carrier != nullptr);
  reduce_prim_carrier->set_value(matmul_prim_c);
  return lite::RET_OK;
}
228 
int MulReduceFusion::GenerateSqueeze(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // Insert a Squeeze(axis_) after `cnode` (the Reduce being converted to a
  // MatMul) so the fused output keeps the rank the original keep_dims=false
  // Reduce produced. All of `cnode`'s users are redirected to the new Squeeze.
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(cnode != nullptr);
  auto manager = func_graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is a nullptr.");
  auto squeeze = std::make_shared<ops::Squeeze>();
  MS_CHECK_TRUE_MSG(squeeze != nullptr, lite::RET_ERROR, "Squeeze create failed.");
  squeeze->set_axis(std::vector<int64_t>{axis_});
  auto squeeze_prim = squeeze->GetPrim();
  MS_CHECK_TRUE_MSG(squeeze_prim != nullptr, lite::RET_ERROR, "Squeeze create failed.");
  auto squeeze_cnode = func_graph->NewCNode(squeeze_prim, {cnode});
  MS_CHECK_TRUE_MSG(squeeze_cnode != nullptr, lite::RET_ERROR, "Squeeze-cnode create failed.");
  // Name the squeeze after the original Mul (cnode's first input).
  auto mul_op = cnode->input(1);
  MS_ASSERT(mul_op != nullptr);
  squeeze_cnode->set_fullname_with_scope(mul_op->fullname_with_scope() + "/Squeeze");
  auto success = manager->Replace(cnode, squeeze_cnode);
  MS_CHECK_TRUE_MSG(success, lite::RET_ERROR, "Replace old node failed.");
  // Record this squeeze for PostProcess, which may later merge sibling
  // squeezes feeding a common Concat. Stored pair: (squeeze axis,
  // rank of the Mul's output minus the squeezed dimension).
  auto &shape_infos = preprocessor_.GetShapeContainer();
  MS_ASSERT(shape_infos.find(mul_op) != shape_infos.end());
  auto &out_shape_infos = shape_infos.at(mul_op).second;
  MS_ASSERT(!out_shape_infos.empty());
  squeeze_infos_[squeeze_cnode] = std::make_pair(axis_, out_shape_infos.front().size() - 1);
  return lite::RET_OK;
}
253 
GenerateMul(const FuncGraphPtr & func_graph,const CNodePtr & cnode)254 int MulReduceFusion::GenerateMul(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
255   MS_ASSERT(func_graph != nullptr);
256   MS_ASSERT(cnode != nullptr);
257   if (coeff_ == 1.0f) {
258     return lite::RET_OK;
259   }
260   auto manager = func_graph->manager();
261   MS_CHECK_TRUE_MSG(manager != nullptr, lite::RET_ERROR, "Manager is a nullptr.");
262   auto mul = std::make_shared<ops::MulFusion>();
263   MS_CHECK_TRUE_MSG(mul != nullptr, lite::RET_ERROR, "Mul create failed.");
264   auto mul_prim = mul->GetPrim();
265   MS_CHECK_TRUE_MSG(mul_prim != nullptr, lite::RET_ERROR, "Mul create failed.");
266   auto old_mul_op = cnode->input(1);
267   MS_ASSERT(old_mul_op != nullptr);
268   auto second_input_node =
269     BuildFloatValueParameterNode(func_graph, coeff_, old_mul_op->fullname_with_scope() + "/scale");
270   MS_CHECK_TRUE_MSG(second_input_node != nullptr, lite::RET_ERROR, "Mul second-input create failed.");
271   auto mul_cnode = func_graph->NewCNode(mul_prim, {cnode, second_input_node});
272   MS_CHECK_TRUE_MSG(mul_cnode != nullptr, lite::RET_ERROR, "Mul-cnode create failed.");
273   mul_cnode->set_fullname_with_scope(old_mul_op->fullname_with_scope());
274   auto success = manager->Replace(cnode, mul_cnode);
275   MS_CHECK_TRUE_MSG(success, lite::RET_ERROR, "Replace old node failed.");
276   return lite::RET_OK;
277 }
278 
bool MulReduceFusion::CheckBasicCond(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // Validate that `cnode` (a ReduceFusion) and its producing Mul match the
  // fusable pattern. Side effects: fills keep_dim_, reduce_mode_, and (via the
  // helper checks) axis_, coeff_, exchange_, transpose flags, and gather_.
  MS_ASSERT(cnode != nullptr);
  if (cnode->size() < kInputSizeThree) {
    return false;  // Need both the data input and the axis input.
  }
  if (IsMarkedTrainOp(cnode)) {
    return false;
  }
  auto prim = GetCNodePrimitive(cnode);
  MS_ASSERT(prim != nullptr);
  // reduce_to_end collapses multiple trailing axes; that is not a single-axis
  // reduction, so the MatMul rewrite does not apply.
  bool is_to_end = prim->GetAttr(ops::kReduceToEnd) != nullptr && GetValue<bool>(prim->GetAttr(ops::kReduceToEnd));
  if (is_to_end) {
    return false;
  }
  keep_dim_ = prim->GetAttr(ops::kKeepDims) != nullptr && GetValue<bool>(prim->GetAttr(ops::kKeepDims));
  auto mode_attr = prim->GetAttr(ops::kMode);
  if (mode_attr == nullptr) {
    return false;
  }
  reduce_mode_ = static_cast<int>(GetValue<int64_t>(mode_attr));
  // Only Sum and Mean map onto a MatMul (Mean additionally needs 1/N scaling).
  if (reduce_mode_ != ReduceMode::Reduce_Sum && reduce_mode_ != ReduceMode::Reduce_Mean) {
    return false;
  }
  auto first_input = cnode->input(1);
  if (!utils::isa<CNode>(first_input)) {
    return false;
  }
  if (!CheckPrimitiveType(first_input, prim::kPrimMulFusion)) {
    return false;
  }
  if (IsMarkedTrainOp(first_input->cast<CNodePtr>())) {
    return false;
  }
  auto mul_prim = GetCNodePrimitive(first_input);
  MS_ASSERT(mul_prim != nullptr);
  // A fused activation on the Mul would be lost by the rewrite: reject it.
  auto act_type = mul_prim->GetAttr(ops::kActivationType) == nullptr
                    ? ActivationType::NO_ACTIVATION
                    : GetValue<int64_t>(mul_prim->GetAttr(ops::kActivationType));
  if (act_type != ActivationType::NO_ACTIVATION) {
    return false;
  }
  // The Mul is consumed (absorbed) by the fusion, so no one else may use it.
  if (IsMultiOutputTensors(func_graph, first_input)) {
    return false;
  }
  bool is_axis_meet = CheckAxisCond(cnode);
  if (!is_axis_meet) {
    return false;
  }
  bool is_shape_meet = CheckShapeCond(cnode);
  if (!is_shape_meet) {
    return false;
  }
  return CheckGatherOp(func_graph, cnode);
}
333 
CheckAxisCond(const CNodePtr & cnode)334 bool MulReduceFusion::CheckAxisCond(const CNodePtr &cnode) {
335   MS_ASSERT(cnode != nullptr);
336   auto &shape_container = preprocessor_.GetShapeContainer();
337   auto first_input = cnode->input(1);
338   if (shape_container.find(first_input) == shape_container.end()) {
339     return false;
340   }
341   if (shape_container.at(first_input).second.empty()) {
342     return false;
343   }
344   auto in_shape = shape_container.at(first_input).second.front();
345   auto second_input = cnode->input(kInputIndexTwo);
346   if (second_input == nullptr || utils::isa<CNode>(second_input)) {
347     return false;
348   }
349   lite::DataInfo data_info;
350   auto ret = lite::FetchConstData(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, false);
351   MS_CHECK_TRUE_MSG(ret == lite::RET_OK, false, "Fetch reduceOp's axis failed.");
352   auto element_num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1L, std::multiplies<int64_t>());
353   if (data_info.data_ptr_ == nullptr || element_num != 1) {
354     return false;
355   }
356   if (data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32) {
357     axis_ = *(static_cast<int *>(data_info.data_ptr_));
358   } else if (data_info.data_type_ == kNumberTypeInt64) {
359     axis_ = static_cast<int>(*(static_cast<int64_t *>(data_info.data_ptr_)));
360   } else {
361     return false;
362   }
363   if (axis_ > 0) {
364     axis_ -= static_cast<int>(in_shape.size());
365   }
366   if (axis_ != kReciprocalFirstIndex && axis_ != kReciprocalSecondIndex) {
367     return false;
368   }
369   return true;
370 }
371 
CheckShapeCond(const CNodePtr & cnode)372 bool MulReduceFusion::CheckShapeCond(const CNodePtr &cnode) {
373   MS_ASSERT(cnode != nullptr);
374   auto &shape_container = preprocessor_.GetShapeContainer();
375   auto first_input = cnode->input(1);
376   if (shape_container.find(first_input) == shape_container.end()) {
377     return false;
378   }
379   if (shape_container.at(first_input).first.size() != kInputSizeTwo) {
380     return false;
381   }
382   auto mul_in0_shape = shape_container.at(first_input).first.front();
383   auto mul_in1_shape = shape_container.at(first_input).first.back();
384   if (mul_in0_shape.size() < kInputSizeTwo || mul_in1_shape.size() < kInputSizeTwo) {
385     return false;
386   }
387   if (mul_in0_shape.back() <= 0 || mul_in0_shape[mul_in0_shape.size() - C2NUM] <= 0 || mul_in1_shape.back() <= 0 ||
388       mul_in1_shape[mul_in1_shape.size() - C2NUM] <= 0) {
389     return false;
390   }
391   if (axis_ == kReciprocalFirstIndex) {
392     if (mul_in0_shape.back() != mul_in1_shape.back() ||
393         (mul_in0_shape[mul_in0_shape.size() - C2NUM] != 1 && mul_in1_shape[mul_in1_shape.size() - C2NUM] != 1)) {
394       return false;
395     }
396     exchange_ = mul_in1_shape[mul_in1_shape.size() - C2NUM] != 1;
397     transpose_a_ = false;
398     transpose_b_ = true;
399     MS_ASSERT(mul_in0_shape.back() != 0);
400     coeff_ = 1.0f / static_cast<float>(mul_in0_shape.back());
401     return true;
402   }
403   if (axis_ == kReciprocalSecondIndex) {
404     if (mul_in0_shape[mul_in0_shape.size() - C2NUM] != mul_in1_shape[mul_in1_shape.size() - C2NUM] ||
405         (mul_in0_shape.back() != 1 && mul_in1_shape.back() != 1)) {
406       return false;
407     }
408     exchange_ = mul_in0_shape.back() != 1;
409     transpose_a_ = true;
410     transpose_b_ = false;
411     MS_ASSERT(mul_in0_shape[mul_in0_shape.size() - C2NUM] != 0);
412     coeff_ = 1.0f / static_cast<float>(mul_in0_shape[mul_in0_shape.size() - C2NUM]);
413     return true;
414   }
415   return false;
416 }
417 
CheckGatherOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)418 bool MulReduceFusion::CheckGatherOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
419   MS_ASSERT(cnode != nullptr);
420   if (reduce_mode_ == ReduceMode::Reduce_Sum) {
421     return true;
422   }
423   if (reduce_mode_ != ReduceMode::Reduce_Mean) {
424     return false;
425   }
426   auto mul_op = cnode->input(1);
427   if (!utils::isa<CNode>(mul_op)) {
428     return false;
429   }
430   auto mul_op_cnode = mul_op->cast<CNodePtr>();
431   for (size_t i = 1; i < mul_op_cnode->size(); ++i) {
432     if (!utils::isa<CNode>(mul_op_cnode->input(i))) {
433       continue;
434     }
435     if (CheckPrimitiveType(mul_op_cnode->input(i), prim::kPrimGather)) {
436       gather_ = mul_op_cnode->input(i)->cast<CNodePtr>();
437       break;
438     }
439   }
440   if (gather_ == nullptr) {
441     return false;
442   }
443   if (IsMarkedTrainOp(gather_)) {
444     return false;
445   }
446   if (IsMultiOutputTensors(func_graph, gather_)) {
447     return false;
448   }
449   return true;
450 }
451 
bool MulReduceFusion::CheckConcatOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  // Verify every input of this Concat is a single-consumer Squeeze generated by
  // this pass, all sharing the same squeeze axis and output rank. On success,
  // concat_axis_ holds the concat axis remapped into the pre-squeeze
  // (rank = out_dims + 1) coordinate system, ready for the Squeeze-after-Concat
  // rewrite in PostProcessSqueezeWithConcat.
  MS_ASSERT(cnode != nullptr);
  int axis{0};       // squeeze axis shared by all inputs
  int out_dims{0};   // rank of each squeeze's output
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto in_node = cnode->input(i);
    if (!utils::isa<CNode>(in_node)) {
      return false;
    }
    auto in_cnode = in_node->cast<CNodePtr>();
    if (squeeze_infos_.find(in_cnode) == squeeze_infos_.end()) {
      return false;  // Not one of the squeezes created by this pass.
    }
    if (IsMultiOutputTensors(func_graph, in_node)) {
      return false;  // The squeeze will be removed, so it must feed only this Concat.
    }
    if (i == 1) {
      axis = squeeze_infos_[in_cnode].first;
      out_dims = squeeze_infos_[in_cnode].second;
    } else {
      // All inputs must agree on axis and rank for a single post-Concat Squeeze.
      if (squeeze_infos_[in_cnode].first != axis || squeeze_infos_[in_cnode].second != out_dims) {
        return false;
      }
    }
  }
  auto concat_prim = GetCNodePrimitive(cnode);
  MS_CHECK_TRUE_RET(concat_prim != nullptr, false);
  concat_axis_ = concat_prim->GetAttr(ops::kAxis) == nullptr
                   ? 0
                   : static_cast<int>(GetValue<int64_t>(concat_prim->GetAttr(ops::kAxis)));
  // Normalize the squeeze axis to a non-negative index in the un-squeezed
  // layout, whose rank is out_dims + 1.
  axis = axis < 0 ? axis + out_dims + 1 : axis;
  MS_CHECK_TRUE_RET(axis >= 0 && axis <= out_dims, false);
  // Normalize the concat axis within the squeezed layout, then shift it past
  // the dimension that will be re-inserted before the Concat.
  concat_axis_ = concat_axis_ < 0 ? concat_axis_ + out_dims : concat_axis_;
  MS_CHECK_TRUE_RET(concat_axis_ >= 0 && concat_axis_ < out_dims, false);
  if (concat_axis_ >= axis) {
    ++concat_axis_;
  }
  return true;
}
491 }  // namespace opt
492 }  // namespace mindspore
493