• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/scale_scale_fusion.h"
19 #include <functional>
20 #include <memory>
21 #include "mindspore/core/ops/lite_ops.h"
22 #include "tools/converter/quantizer/quant_param_holder.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "tools/common/tensor_util.h"
25 #include "ops/fusion/scale_fusion.h"
26 #include "securec/include/securec.h"
27 #include "nnacl/op_base.h"
28 #include "ops/op_utils.h"
29 
30 namespace mindspore::opt {
namespace {
// Input layout of a Scale cnode: [0] primitive, [1] data input, [2] weight, [3] optional bias.
constexpr size_t kScaleWeightIndex = 2;  // index of the scale (gamma) weight input
constexpr size_t kScaleBiasIndex = 3;    // index of the optional bias (beta) input
constexpr size_t kScaleNoBiasLen = 3;    // cnode input count when no bias is present
constexpr size_t kScaleWithBiasLen = 4;  // cnode input count when bias is present
}  // namespace
37 
DefinePattern() const38 const BaseRef ScaleScaleFusion::DefinePattern() const {
39   auto is_scale_up = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
40   MS_CHECK_TRUE_RET(is_scale_up != nullptr, {});
41   auto is_param = std::make_shared<CondVar>(IsParamNode);
42   MS_CHECK_TRUE_RET(is_param != nullptr, {});
43   auto is_seq_var = std::make_shared<SeqVar>();
44   MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
45   auto is_scale_down = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
46   MS_CHECK_TRUE_RET(is_scale_down != nullptr, {});
47   return VectorRef({is_scale_down, is_scale_up, is_param, is_seq_var});
48 }
49 
CheckScaleNode(const CNodePtr & scale_cnode) const50 bool ScaleScaleFusion::CheckScaleNode(const CNodePtr &scale_cnode) const {
51   MS_ASSERT(scale_cnode != nullptr);
52   if (IsMarkedTrainOp(scale_cnode)) {
53     return false;
54   }
55   MS_CHECK_TRUE_RET(scale_cnode->size() >= kScaleNoBiasLen, false);
56   auto scale_prim = ops::GetOperator<ops::ScaleFusion>(scale_cnode->input(FIRST_INPUT));
57   MS_CHECK_TRUE_RET(scale_prim != nullptr, false);
58   auto scale_prim_c = scale_prim->GetPrim();
59   MS_CHECK_TRUE_RET(scale_prim_c != nullptr, false);
60   auto quant_attr = scale_prim_c->GetAttr("quant_params");
61   if (quant_attr != nullptr) {
62     auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
63     MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
64     auto quant_params = quant_param_holder->get_input_quant_params();
65     bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> &params) {
66       return !params.empty() && params.front().inited;
67     });
68     if (is_quant) {
69       return false;
70     }
71   }
72 
73   auto scale_weight_node = scale_cnode->input(kScaleWeightIndex);
74   if (!IsParamNode(scale_weight_node)) {
75     return false;
76   }
77   if (scale_cnode->size() == kScaleWithBiasLen) {
78     auto scale_bias_node = scale_cnode->input(kScaleWeightIndex);
79     MS_CHECK_TRUE_RET(scale_bias_node != nullptr, false);
80     if (!IsParamNode(scale_bias_node)) {
81       return false;
82     }
83   }
84   return true;
85 }
86 
// Collects everything the fusion needs from the two Scale cnodes into member
// state: the upstream data-input shape, both (normalized, non-negative) axes,
// and the weight/bias tensors of both nodes.
// NOTE(review): this const method writes members (scale_input_shape_,
// up_scale_axis_, down_scale_axis_, *_tensor_) — presumably they are declared
// mutable in the header; confirm there.
// Returns lite::RET_OK on success, lite::RET_ERROR otherwise.
int ScaleScaleFusion::GetInputParamsAndTensors(const CNodePtr &up_scale_cnode, const CNodePtr &down_scale_cnode) const {
  MS_ASSERT(up_scale_cnode != nullptr && down_scale_cnode != nullptr);
  // Shape of the upstream scale's data input; both axes are normalized against it.
  auto abstract = GetCNodeInputAbstract(up_scale_cnode, SECOND_INPUT);
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Get abstract failed.";
    return lite::RET_ERROR;
  }
  if (FetchShapeFromAbstract(abstract, &scale_input_shape_) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch shape from abstract failed.";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(!scale_input_shape_.empty(), lite::RET_ERROR);

  // Normalize a possibly-negative axis attribute to the range [0, rank).
  auto up_scale_prim = ops::GetOperator<ops::ScaleFusion>(up_scale_cnode->input(FIRST_INPUT));
  MS_CHECK_TRUE_RET(up_scale_prim != nullptr, lite::RET_ERROR);
  auto up_scale_prim_c = up_scale_prim->GetPrim();
  MS_CHECK_TRUE_RET(up_scale_prim_c != nullptr && up_scale_prim_c->GetAttr(ops::kAxis), lite::RET_ERROR);
  auto axis = up_scale_prim->get_axis();
  up_scale_axis_ = axis < 0 ? axis + static_cast<int>(scale_input_shape_.size()) : axis;
  auto down_scale_prim = ops::GetOperator<ops::ScaleFusion>(down_scale_cnode->input(FIRST_INPUT));
  MS_CHECK_TRUE_RET(down_scale_prim != nullptr, lite::RET_ERROR);
  auto down_scale_prim_c = down_scale_prim->GetPrim();
  MS_CHECK_TRUE_RET(down_scale_prim_c != nullptr && down_scale_prim_c->GetAttr(ops::kAxis), lite::RET_ERROR);
  axis = down_scale_prim->get_axis();
  down_scale_axis_ = axis < 0 ? axis + static_cast<int>(scale_input_shape_.size()) : axis;

  // Upstream weight (and optional bias) must be float32 constant tensors;
  // the fusion folds them arithmetically, so other dtypes are rejected.
  auto up_weight_param = up_scale_cnode->input(THIRD_INPUT);
  MS_CHECK_TRUE_RET(up_weight_param != nullptr, lite::RET_ERROR);
  up_weight_tensor_ = GetTensorInfo(up_weight_param);
  MS_CHECK_TRUE_RET(up_weight_tensor_ != nullptr, lite::RET_ERROR);
  MS_CHECK_TRUE_RET(
    up_weight_tensor_->data_type() == kNumberTypeFloat || up_weight_tensor_->data_type() == kNumberTypeFloat32,
    lite::RET_ERROR);
  if (up_scale_cnode->size() == kScaleWithBiasLen) {
    auto up_bias_param = up_scale_cnode->input(FOURTH_INPUT);
    MS_CHECK_TRUE_RET(up_bias_param != nullptr, lite::RET_ERROR);
    up_bias_tensor_ = GetTensorInfo(up_bias_param);
    MS_CHECK_TRUE_RET(up_bias_tensor_ != nullptr, lite::RET_ERROR);
    MS_CHECK_TRUE_RET(
      up_bias_tensor_->data_type() == kNumberTypeFloat || up_bias_tensor_->data_type() == kNumberTypeFloat32,
      lite::RET_ERROR);
  }

  // Same extraction and dtype checks for the downstream scale's weight/bias.
  auto down_weight_param = down_scale_cnode->input(THIRD_INPUT);
  MS_CHECK_TRUE_RET(down_weight_param != nullptr, lite::RET_ERROR);
  down_weight_tensor_ = GetTensorInfo(down_weight_param);
  MS_CHECK_TRUE_RET(down_weight_tensor_ != nullptr, lite::RET_ERROR);
  MS_CHECK_TRUE_RET(
    down_weight_tensor_->data_type() == kNumberTypeFloat || down_weight_tensor_->data_type() == kNumberTypeFloat32,
    lite::RET_ERROR);
  if (down_scale_cnode->size() == kScaleWithBiasLen) {
    auto down_bias_param = down_scale_cnode->input(FOURTH_INPUT);
    MS_CHECK_TRUE_RET(down_bias_param != nullptr, lite::RET_ERROR);
    down_bias_tensor_ = GetTensorInfo(down_bias_param);
    MS_CHECK_TRUE_RET(down_bias_tensor_ != nullptr, lite::RET_ERROR);
    MS_CHECK_TRUE_RET(
      down_bias_tensor_->data_type() == kNumberTypeFloat || down_bias_tensor_->data_type() == kNumberTypeFloat32,
      lite::RET_ERROR);
  }
  return lite::RET_OK;
}
148 
// Computes the element-wise product of two scale tensors whose broadcast
// windows (anchored at up_scale_axis_ / down_scale_axis_ within
// scale_input_shape_) may only partially overlap. The tensor anchored at the
// larger axis is first tiled ("expanded") to cover the union of both windows,
// then multiplied with the other tensor. Returns the product tensor covering
// [min(axes), max(end indices)) of scale_input_shape_, or nullptr on failure.
// Side effect: sets expand_shape_ to the shape of the result window.
// NOTE(review): assumes the two windows are nested/aligned such that
// left_data indexed by the outer index alone is correct — confirm the pass is
// only applied when shapes satisfy this; the callers feed weights validated by
// GetInputParamsAndTensors.
tensor::TensorPtr ScaleScaleFusion::GetMultiplyResultTensorInfo(const tensor::TensorPtr &left_tensor,
                                                                const tensor::TensorPtr &right_tensor) const {
  MS_ASSERT(left_tensor != nullptr && right_tensor != nullptr);
  auto left_weight_shape = left_tensor->shape();
  auto right_weight_shape = right_tensor->shape();
  // End index (exclusive) of each tensor's window inside scale_input_shape_.
  size_t left_end_idx = up_scale_axis_ + left_weight_shape.size();
  size_t right_end_idx = down_scale_axis_ + right_weight_shape.size();
  auto begin_idx = MSMIN(up_scale_axis_, down_scale_axis_);
  auto tmp_idx = MSMAX(up_scale_axis_, down_scale_axis_);
  auto tmp_end_idx = up_scale_axis_ < down_scale_axis_ ? right_end_idx : left_end_idx;
  // expand_size: number of copies needed to tile the later-starting tensor
  // over the dims [begin_idx, tmp_idx) that it does not cover.
  size_t expand_size = 1;
  for (size_t i = begin_idx; i < tmp_idx; i++) {
    MS_CHECK_TRUE_RET(!SIZE_MUL_OVERFLOW(expand_size, static_cast<size_t>(scale_input_shape_.at(i))), nullptr);
    expand_size *= static_cast<size_t>(scale_input_shape_.at(i));
  }
  // ele_size: element count of one tile, i.e. the later-starting tensor's window.
  size_t ele_size = 1;
  for (size_t i = tmp_idx; i < tmp_end_idx; i++) {
    MS_CHECK_TRUE_RET(!SIZE_MUL_OVERFLOW(ele_size, static_cast<size_t>(scale_input_shape_.at(i))), nullptr);
    ele_size *= static_cast<size_t>(scale_input_shape_.at(i));
  }
  MS_CHECK_LE(expand_size * ele_size * sizeof(float), MAX_MALLOC_SIZE, nullptr);
  float *expand_data = reinterpret_cast<float *>(malloc(expand_size * ele_size * sizeof(float)));
  if (expand_data == nullptr) {
    MS_LOG(ERROR) << "malloc data failed.";
    return nullptr;
  }
  // Tile the later-starting tensor expand_size times into expand_data.
  auto tmp_tensor = up_scale_axis_ < down_scale_axis_ ? right_tensor : left_tensor;
  for (size_t i = 0; i < expand_size; i++) {
    if (memcpy_s(expand_data + i * ele_size, ele_size * sizeof(float), tmp_tensor->data_c(), tmp_tensor->Size()) !=
        EOK) {
      MS_LOG(ERROR) << "memcpy data failed.";
      free(expand_data);
      return nullptr;
    }
  }

  // Pick which buffer plays "left" (broadcast by outer index) and which plays
  // "right" (full-resolution) in the multiply below, depending on whose window
  // starts/ends first.
  float *left_data = nullptr;
  float *right_data = nullptr;
  if (up_scale_axis_ < down_scale_axis_) {
    left_data = left_end_idx < right_end_idx ? static_cast<float *>(left_tensor->data_c()) : expand_data;
    right_data = left_end_idx < right_end_idx ? expand_data : static_cast<float *>(left_tensor->data_c());
  } else {
    left_data = left_end_idx < right_end_idx ? expand_data : static_cast<float *>(right_tensor->data_c());
    right_data = left_end_idx < right_end_idx ? static_cast<float *>(right_tensor->data_c()) : expand_data;
  }
  if (left_data == nullptr || right_data == nullptr) {
    free(expand_data);
    return nullptr;
  }

  // Result window spans the union of both tensors' windows.
  auto end_idx = MSMAX(left_end_idx, right_end_idx);
  expand_shape_.assign(scale_input_shape_.begin() + begin_idx, scale_input_shape_.begin() + end_idx);
  // CreateTensorInfo with nullptr data allocates a zero-filled buffer we write into.
  auto tensor_info = lite::CreateTensorInfo(nullptr, 0, expand_shape_, left_tensor->data_type());
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed.";
    free(expand_data);
    return nullptr;
  }
  float *new_weight_data = reinterpret_cast<float *>(tensor_info->data_c());
  MS_ASSERT(new_weight_data != nullptr);
  // outer: dims covered only by the shorter window; inner: remaining dims.
  size_t outer_size =
    std::accumulate(scale_input_shape_.begin() + begin_idx,
                    scale_input_shape_.begin() + MSMIN(left_end_idx, right_end_idx), 1, std::multiplies<size_t>());
  size_t inner_size = std::accumulate(scale_input_shape_.begin() + MSMIN(left_end_idx, right_end_idx),
                                      scale_input_shape_.begin() + end_idx, 1, std::multiplies<size_t>());
  for (size_t i = 0; i < outer_size; i++) {
    for (size_t j = 0; j < inner_size; j++) {
      new_weight_data[i * inner_size + j] = left_data[i] * right_data[i * inner_size + j];
    }
  }
  free(expand_data);
  return tensor_info;
}
222 
GenerateNewWeightNode(const FuncGraphPtr & func_graph,const std::string & name) const223 ParameterPtr ScaleScaleFusion::GenerateNewWeightNode(const FuncGraphPtr &func_graph, const std::string &name) const {
224   auto param = func_graph->add_parameter();
225   MS_CHECK_TRUE_RET(param != nullptr, nullptr);
226   auto new_weight_tensor = GetMultiplyResultTensorInfo(up_weight_tensor_, down_weight_tensor_);
227   if (new_weight_tensor == nullptr) {
228     MS_LOG(ERROR) << "Get new weight tensor failed.";
229     return nullptr;
230   }
231   if (lite::InitParameterFromTensorInfo(param, new_weight_tensor) != lite::RET_OK) {
232     MS_LOG(ERROR) << "Init parameter from tensor info failed.";
233     return nullptr;
234   }
235   param->set_name(name);
236   return param;
237 }
238 
// Builds the fused bias parameter: new_bias = up_bias * down_weight
// (+ down_bias, broadcast-added when the downstream scale also has a bias).
// Relies on expand_shape_ having been set by the GetMultiplyResultTensorInfo
// call inside this function. Returns nullptr on any failure.
ParameterPtr ScaleScaleFusion::GenerateNewBiasNode(const FuncGraphPtr &func_graph, const std::string &name) const {
  auto param = func_graph->add_parameter();
  MS_CHECK_TRUE_RET(param != nullptr, nullptr);
  // Start from up_bias scaled by the downstream weight; this also fixes
  // expand_shape_ to the result window used below.
  tensor::TensorPtr tensor_info = GetMultiplyResultTensorInfo(up_bias_tensor_, down_weight_tensor_);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed.";
    return nullptr;
  }
  if (down_bias_tensor_ != nullptr) {
    auto bias_shape = down_bias_tensor_->shape_c();
    // axis_diff: leading dims of the result window not covered by down_bias.
    int axis_diff = down_scale_axis_ - MSMIN(up_scale_axis_, down_scale_axis_);
    // end_idx_diff (<= 0): trailing dims of the result window past down_bias's end.
    int end_idx_diff = static_cast<int>(down_scale_axis_ + bias_shape.size()) -
                       static_cast<int>(MSMAX(down_scale_axis_ + bias_shape.size(),
                                              up_scale_axis_ + up_weight_tensor_->shape_c().size()));
    size_t outer_size = axis_diff > 0 ? std::accumulate(expand_shape_.begin(), expand_shape_.begin() + axis_diff, 1,
                                                        std::multiplies<size_t>())
                                      : 1;
    size_t inner_size = end_idx_diff < 0 ? std::accumulate(expand_shape_.end() + end_idx_diff, expand_shape_.end(), 1,
                                                           std::multiplies<size_t>())
                                         : 1;
    size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<size_t>());
    float *bias_data = reinterpret_cast<float *>(down_bias_tensor_->data_c());
    float *data = reinterpret_cast<float *>(tensor_info->data_c());
    MS_ASSERT(bias_data != nullptr && data != nullptr);
    // Broadcast-add down_bias over the [outer, bias, inner] decomposition of
    // the result buffer.
    for (size_t i = 0; i < outer_size; i++) {
      for (size_t j = 0; j < bias_size; j++) {
        for (size_t k = 0; k < inner_size; k++) {
          data[i * bias_size * inner_size + j * inner_size + k] += bias_data[j];
        }
      }
    }
  }
  if (lite::InitParameterFromTensorInfo(param, tensor_info) != lite::RET_OK) {
    MS_LOG(ERROR) << "Init parameter from tensor info failed.";
    return nullptr;
  }
  param->set_name(name);
  return param;
}
278 
Process(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const EquivPtr &) const279 const AnfNodePtr ScaleScaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
280                                            const EquivPtr &) const {
281   MS_ASSERT(func_graph != nullptr && node != nullptr);
282   auto down_scale_cnode = node->cast<CNodePtr>();
283   MS_CHECK_TRUE_RET(down_scale_cnode != nullptr, nullptr);
284   auto up_scale_node = down_scale_cnode->input(SECOND_INPUT);
285   MS_CHECK_TRUE_RET(up_scale_node != nullptr, nullptr);
286   auto up_scale_cnode = up_scale_node->cast<CNodePtr>();
287   MS_CHECK_TRUE_RET(up_scale_cnode != nullptr, nullptr);
288   if (IsMultiOutputTensors(func_graph, up_scale_cnode)) {
289     return nullptr;
290   }
291   if (!CheckScaleNode(up_scale_cnode) || !CheckScaleNode(down_scale_cnode)) {
292     return nullptr;
293   }
294   auto scale_prim = ops::GetOperator<ops::ScaleFusion>(up_scale_cnode->input(FIRST_INPUT));
295   MS_CHECK_TRUE_RET(scale_prim != nullptr, nullptr);
296   auto scale_prim_c = scale_prim->GetPrim();
297   MS_CHECK_TRUE_RET(scale_prim_c != nullptr, nullptr);
298   if (scale_prim_c->GetAttr(ops::kActivationType) != nullptr && scale_prim->get_activation_type() != NO_ACTIVATION) {
299     return nullptr;
300   }
301 
302   if (GetInputParamsAndTensors(up_scale_cnode, down_scale_cnode) != lite::RET_OK) {
303     MS_LOG(ERROR) << "Get inputs failed.";
304     return nullptr;
305   }
306   auto new_weight_param = GenerateNewWeightNode(func_graph, down_scale_cnode->fullname_with_scope() + "_weight");
307   if (new_weight_param == nullptr) {
308     MS_LOG(ERROR) << "Generate new weight parameter node failed.";
309     return nullptr;
310   }
311   auto down_scale_prim = ops::GetOperator<ops::ScaleFusion>(down_scale_cnode->input(FIRST_INPUT));
312   MS_CHECK_TRUE_RET(down_scale_prim != nullptr, nullptr);
313   auto down_scale_prim_c = down_scale_prim->GetPrim();
314   MS_CHECK_TRUE_RET(down_scale_prim_c != nullptr && down_scale_prim_c->GetAttr(ops::kAxis) != nullptr, nullptr);
315   down_scale_prim->set_axis(MSMIN(up_scale_axis_, down_scale_axis_));
316 
317   auto manager = func_graph->manager();
318   MS_ASSERT(manager != nullptr);
319   manager->SetEdge(down_scale_cnode, 1, up_scale_cnode->input(SECOND_INPUT));
320   manager->SetEdge(down_scale_cnode, kInputIndexTwo, new_weight_param);
321   if (up_scale_cnode->size() == kScaleWithBiasLen) {
322     ParameterPtr new_bias_param = GenerateNewBiasNode(func_graph, down_scale_cnode->fullname_with_scope() + "_bias");
323     if (new_bias_param == nullptr) {
324       MS_LOG(ERROR) << "Generate new weight parameter node failed.";
325       return nullptr;
326     }
327     if (down_scale_cnode->size() == kScaleWithBiasLen) {
328       manager->SetEdge(down_scale_cnode, kInputIndexThree, new_bias_param);
329     } else {
330       manager->AddEdge(down_scale_cnode, new_bias_param);
331     }
332   }
333 
334   return nullptr;
335 }
336 }  // namespace mindspore::opt
337