1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/scale_scale_fusion.h"
19 #include <functional>
20 #include <memory>
21 #include "mindspore/core/ops/lite_ops.h"
22 #include "tools/converter/quantizer/quant_param_holder.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "tools/common/tensor_util.h"
25 #include "ops/fusion/scale_fusion.h"
26 #include "securec/include/securec.h"
27 #include "nnacl/op_base.h"
28 #include "ops/op_utils.h"
29
30 namespace mindspore::opt {
namespace {
// Input indices of a ScaleFusion CNode: [0]=primitive, [1]=data, [2]=scale weight, [3]=bias.
constexpr size_t kScaleWeightIndex = 2;
constexpr size_t kScaleBiasIndex = 3;
// Total input count (primitive included) of a Scale node without / with a bias input.
constexpr size_t kScaleNoBiasLen = 3;
constexpr size_t kScaleWithBiasLen = 4;
}  // namespace
37
DefinePattern() const38 const BaseRef ScaleScaleFusion::DefinePattern() const {
39 auto is_scale_up = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
40 MS_CHECK_TRUE_RET(is_scale_up != nullptr, {});
41 auto is_param = std::make_shared<CondVar>(IsParamNode);
42 MS_CHECK_TRUE_RET(is_param != nullptr, {});
43 auto is_seq_var = std::make_shared<SeqVar>();
44 MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
45 auto is_scale_down = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
46 MS_CHECK_TRUE_RET(is_scale_down != nullptr, {});
47 return VectorRef({is_scale_down, is_scale_up, is_param, is_seq_var});
48 }
49
CheckScaleNode(const CNodePtr & scale_cnode) const50 bool ScaleScaleFusion::CheckScaleNode(const CNodePtr &scale_cnode) const {
51 MS_ASSERT(scale_cnode != nullptr);
52 if (IsMarkedTrainOp(scale_cnode)) {
53 return false;
54 }
55 MS_CHECK_TRUE_RET(scale_cnode->size() >= kScaleNoBiasLen, false);
56 auto scale_prim = ops::GetOperator<ops::ScaleFusion>(scale_cnode->input(FIRST_INPUT));
57 MS_CHECK_TRUE_RET(scale_prim != nullptr, false);
58 auto scale_prim_c = scale_prim->GetPrim();
59 MS_CHECK_TRUE_RET(scale_prim_c != nullptr, false);
60 auto quant_attr = scale_prim_c->GetAttr("quant_params");
61 if (quant_attr != nullptr) {
62 auto quant_param_holder = quant_attr->cast<lite::QuantParamHolderPtr>();
63 MS_CHECK_TRUE_RET(quant_param_holder != nullptr, false);
64 auto quant_params = quant_param_holder->get_input_quant_params();
65 bool is_quant = std::any_of(quant_params.begin(), quant_params.end(), [](std::vector<schema::QuantParamT> ¶ms) {
66 return !params.empty() && params.front().inited;
67 });
68 if (is_quant) {
69 return false;
70 }
71 }
72
73 auto scale_weight_node = scale_cnode->input(kScaleWeightIndex);
74 if (!IsParamNode(scale_weight_node)) {
75 return false;
76 }
77 if (scale_cnode->size() == kScaleWithBiasLen) {
78 auto scale_bias_node = scale_cnode->input(kScaleWeightIndex);
79 MS_CHECK_TRUE_RET(scale_bias_node != nullptr, false);
80 if (!IsParamNode(scale_bias_node)) {
81 return false;
82 }
83 }
84 return true;
85 }
86
// Collect into member state everything the fusion needs from the two Scale
// nodes: the shape of the tensor feeding the upper Scale, both scale axes
// (normalized to be non-negative), and the weight/bias tensors (bias only if
// the node has one). All tensors must be float32.
// Returns lite::RET_OK on success, lite::RET_ERROR otherwise.
int ScaleScaleFusion::GetInputParamsAndTensors(const CNodePtr &up_scale_cnode, const CNodePtr &down_scale_cnode) const {
  MS_ASSERT(up_scale_cnode != nullptr && down_scale_cnode != nullptr);
  // Shape of the data input of the upper Scale; both axes below are
  // interpreted against this shape.
  auto abstract = GetCNodeInputAbstract(up_scale_cnode, SECOND_INPUT);
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Get abstract failed.";
    return lite::RET_ERROR;
  }
  if (FetchShapeFromAbstract(abstract, &scale_input_shape_) != lite::RET_OK) {
    MS_LOG(ERROR) << "Fetch shape from abstract failed.";
    return lite::RET_ERROR;
  }
  MS_CHECK_TRUE_RET(!scale_input_shape_.empty(), lite::RET_ERROR);

  auto up_scale_prim = ops::GetOperator<ops::ScaleFusion>(up_scale_cnode->input(FIRST_INPUT));
  MS_CHECK_TRUE_RET(up_scale_prim != nullptr, lite::RET_ERROR);
  auto up_scale_prim_c = up_scale_prim->GetPrim();
  MS_CHECK_TRUE_RET(up_scale_prim_c != nullptr && up_scale_prim_c->GetAttr(ops::kAxis), lite::RET_ERROR);
  auto axis = up_scale_prim->get_axis();
  // Normalize a negative axis into [0, rank).
  up_scale_axis_ = axis < 0 ? axis + static_cast<int>(scale_input_shape_.size()) : axis;
  auto down_scale_prim = ops::GetOperator<ops::ScaleFusion>(down_scale_cnode->input(FIRST_INPUT));
  MS_CHECK_TRUE_RET(down_scale_prim != nullptr, lite::RET_ERROR);
  auto down_scale_prim_c = down_scale_prim->GetPrim();
  MS_CHECK_TRUE_RET(down_scale_prim_c != nullptr && down_scale_prim_c->GetAttr(ops::kAxis), lite::RET_ERROR);
  axis = down_scale_prim->get_axis();
  down_scale_axis_ = axis < 0 ? axis + static_cast<int>(scale_input_shape_.size()) : axis;

  // Upper Scale weight: must exist and be float32.
  auto up_weight_param = up_scale_cnode->input(THIRD_INPUT);
  MS_CHECK_TRUE_RET(up_weight_param != nullptr, lite::RET_ERROR);
  up_weight_tensor_ = GetTensorInfo(up_weight_param);
  MS_CHECK_TRUE_RET(up_weight_tensor_ != nullptr, lite::RET_ERROR);
  MS_CHECK_TRUE_RET(
    up_weight_tensor_->data_type() == kNumberTypeFloat || up_weight_tensor_->data_type() == kNumberTypeFloat32,
    lite::RET_ERROR);
  // Upper Scale bias: only present when the node carries 4 inputs.
  if (up_scale_cnode->size() == kScaleWithBiasLen) {
    auto up_bias_param = up_scale_cnode->input(FOURTH_INPUT);
    MS_CHECK_TRUE_RET(up_bias_param != nullptr, lite::RET_ERROR);
    up_bias_tensor_ = GetTensorInfo(up_bias_param);
    MS_CHECK_TRUE_RET(up_bias_tensor_ != nullptr, lite::RET_ERROR);
    MS_CHECK_TRUE_RET(
      up_bias_tensor_->data_type() == kNumberTypeFloat || up_bias_tensor_->data_type() == kNumberTypeFloat32,
      lite::RET_ERROR);
  }

  // Lower Scale weight and (optional) bias, same constraints as above.
  auto down_weight_param = down_scale_cnode->input(THIRD_INPUT);
  MS_CHECK_TRUE_RET(down_weight_param != nullptr, lite::RET_ERROR);
  down_weight_tensor_ = GetTensorInfo(down_weight_param);
  MS_CHECK_TRUE_RET(down_weight_tensor_ != nullptr, lite::RET_ERROR);
  MS_CHECK_TRUE_RET(
    down_weight_tensor_->data_type() == kNumberTypeFloat || down_weight_tensor_->data_type() == kNumberTypeFloat32,
    lite::RET_ERROR);
  if (down_scale_cnode->size() == kScaleWithBiasLen) {
    auto down_bias_param = down_scale_cnode->input(FOURTH_INPUT);
    MS_CHECK_TRUE_RET(down_bias_param != nullptr, lite::RET_ERROR);
    down_bias_tensor_ = GetTensorInfo(down_bias_param);
    MS_CHECK_TRUE_RET(down_bias_tensor_ != nullptr, lite::RET_ERROR);
    MS_CHECK_TRUE_RET(
      down_bias_tensor_->data_type() == kNumberTypeFloat || down_bias_tensor_->data_type() == kNumberTypeFloat32,
      lite::RET_ERROR);
  }
  return lite::RET_OK;
}
148
// Multiply two scale tensors whose axis ranges over the input shape may
// differ (left starts at up_scale_axis_, right at down_scale_axis_). The
// tensor whose range starts later is tiled ("expanded") over the leading
// dims between the two start axes so both factors cover the same span; the
// resulting shape is stored in expand_shape_. Returns the product tensor,
// or nullptr on overflow/allocation/copy failure.
tensor::TensorPtr ScaleScaleFusion::GetMultiplyResultTensorInfo(const tensor::TensorPtr &left_tensor,
                                                                const tensor::TensorPtr &right_tensor) const {
  MS_ASSERT(left_tensor != nullptr && right_tensor != nullptr);
  auto left_weight_shape = left_tensor->shape();
  auto right_weight_shape = right_tensor->shape();
  // Each tensor covers input axes [axis, axis + its rank).
  size_t left_end_idx = up_scale_axis_ + left_weight_shape.size();
  size_t right_end_idx = down_scale_axis_ + right_weight_shape.size();
  auto begin_idx = MSMIN(up_scale_axis_, down_scale_axis_);
  // tmp_* describe the later-starting tensor, which needs to be tiled.
  auto tmp_idx = MSMAX(up_scale_axis_, down_scale_axis_);
  auto tmp_end_idx = up_scale_axis_ < down_scale_axis_ ? right_end_idx : left_end_idx;
  // expand_size: number of copies needed to tile over axes [begin_idx, tmp_idx).
  size_t expand_size = 1;
  for (size_t i = begin_idx; i < tmp_idx; i++) {
    MS_CHECK_TRUE_RET(!SIZE_MUL_OVERFLOW(expand_size, static_cast<size_t>(scale_input_shape_.at(i))), nullptr);
    expand_size *= static_cast<size_t>(scale_input_shape_.at(i));
  }
  // ele_size: element count of the tensor being tiled, per copy.
  size_t ele_size = 1;
  for (size_t i = tmp_idx; i < tmp_end_idx; i++) {
    MS_CHECK_TRUE_RET(!SIZE_MUL_OVERFLOW(ele_size, static_cast<size_t>(scale_input_shape_.at(i))), nullptr);
    ele_size *= static_cast<size_t>(scale_input_shape_.at(i));
  }
  MS_CHECK_LE(expand_size * ele_size * sizeof(float), MAX_MALLOC_SIZE, nullptr);
  float *expand_data = reinterpret_cast<float *>(malloc(expand_size * ele_size * sizeof(float)));
  if (expand_data == nullptr) {
    MS_LOG(ERROR) << "malloc data failed.";
    return nullptr;
  }
  // Tile the later-starting tensor expand_size times into expand_data.
  // NOTE(review): memcpy_s copies tmp_tensor->Size() bytes into an
  // ele_size * sizeof(float) slot — assumes those sizes match; verify.
  auto tmp_tensor = up_scale_axis_ < down_scale_axis_ ? right_tensor : left_tensor;
  for (size_t i = 0; i < expand_size; i++) {
    if (memcpy_s(expand_data + i * ele_size, ele_size * sizeof(float), tmp_tensor->data_c(), tmp_tensor->Size()) !=
        EOK) {
      MS_LOG(ERROR) << "memcpy data failed.";
      free(expand_data);
      return nullptr;
    }
  }

  // Pick which operand plays the broadcast (outer) role and which the full
  // (inner) role, depending on which range ends first.
  float *left_data = nullptr;
  float *right_data = nullptr;
  if (up_scale_axis_ < down_scale_axis_) {
    left_data = left_end_idx < right_end_idx ? static_cast<float *>(left_tensor->data_c()) : expand_data;
    right_data = left_end_idx < right_end_idx ? expand_data : static_cast<float *>(left_tensor->data_c());
  } else {
    left_data = left_end_idx < right_end_idx ? expand_data : static_cast<float *>(right_tensor->data_c());
    right_data = left_end_idx < right_end_idx ? static_cast<float *>(right_tensor->data_c()) : expand_data;
  }
  if (left_data == nullptr || right_data == nullptr) {
    free(expand_data);
    return nullptr;
  }

  // Result covers input axes [begin_idx, end_idx); remember its shape for
  // callers (GenerateNewBiasNode reads expand_shape_).
  auto end_idx = MSMAX(left_end_idx, right_end_idx);
  expand_shape_.assign(scale_input_shape_.begin() + begin_idx, scale_input_shape_.begin() + end_idx);
  auto tensor_info = lite::CreateTensorInfo(nullptr, 0, expand_shape_, left_tensor->data_type());
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed.";
    free(expand_data);
    return nullptr;
  }
  float *new_weight_data = reinterpret_cast<float *>(tensor_info->data_c());
  MS_ASSERT(new_weight_data != nullptr);
  size_t outer_size =
    std::accumulate(scale_input_shape_.begin() + begin_idx,
                    scale_input_shape_.begin() + MSMIN(left_end_idx, right_end_idx), 1, std::multiplies<size_t>());
  size_t inner_size = std::accumulate(scale_input_shape_.begin() + MSMIN(left_end_idx, right_end_idx),
                                      scale_input_shape_.begin() + end_idx, 1, std::multiplies<size_t>());
  // Broadcast multiply: the left factor covers only the outer dims and is
  // repeated across the inner dims of the right factor.
  for (size_t i = 0; i < outer_size; i++) {
    for (size_t j = 0; j < inner_size; j++) {
      new_weight_data[i * inner_size + j] = left_data[i] * right_data[i * inner_size + j];
    }
  }
  free(expand_data);
  return tensor_info;
}
222
GenerateNewWeightNode(const FuncGraphPtr & func_graph,const std::string & name) const223 ParameterPtr ScaleScaleFusion::GenerateNewWeightNode(const FuncGraphPtr &func_graph, const std::string &name) const {
224 auto param = func_graph->add_parameter();
225 MS_CHECK_TRUE_RET(param != nullptr, nullptr);
226 auto new_weight_tensor = GetMultiplyResultTensorInfo(up_weight_tensor_, down_weight_tensor_);
227 if (new_weight_tensor == nullptr) {
228 MS_LOG(ERROR) << "Get new weight tensor failed.";
229 return nullptr;
230 }
231 if (lite::InitParameterFromTensorInfo(param, new_weight_tensor) != lite::RET_OK) {
232 MS_LOG(ERROR) << "Init parameter from tensor info failed.";
233 return nullptr;
234 }
235 param->set_name(name);
236 return param;
237 }
238
GenerateNewBiasNode(const FuncGraphPtr & func_graph,const std::string & name) const239 ParameterPtr ScaleScaleFusion::GenerateNewBiasNode(const FuncGraphPtr &func_graph, const std::string &name) const {
240 auto param = func_graph->add_parameter();
241 MS_CHECK_TRUE_RET(param != nullptr, nullptr);
242 tensor::TensorPtr tensor_info = GetMultiplyResultTensorInfo(up_bias_tensor_, down_weight_tensor_);
243 if (tensor_info == nullptr) {
244 MS_LOG(ERROR) << "Create tensor info failed.";
245 return nullptr;
246 }
247 if (down_bias_tensor_ != nullptr) {
248 auto bias_shape = down_bias_tensor_->shape_c();
249 int axis_diff = down_scale_axis_ - MSMIN(up_scale_axis_, down_scale_axis_);
250 int end_idx_diff = static_cast<int>(down_scale_axis_ + bias_shape.size()) -
251 static_cast<int>(MSMAX(down_scale_axis_ + bias_shape.size(),
252 up_scale_axis_ + up_weight_tensor_->shape_c().size()));
253 size_t outer_size = axis_diff > 0 ? std::accumulate(expand_shape_.begin(), expand_shape_.begin() + axis_diff, 1,
254 std::multiplies<size_t>())
255 : 1;
256 size_t inner_size = end_idx_diff < 0 ? std::accumulate(expand_shape_.end() + end_idx_diff, expand_shape_.end(), 1,
257 std::multiplies<size_t>())
258 : 1;
259 size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), 1, std::multiplies<size_t>());
260 float *bias_data = reinterpret_cast<float *>(down_bias_tensor_->data_c());
261 float *data = reinterpret_cast<float *>(tensor_info->data_c());
262 MS_ASSERT(bias_data != nullptr && data != nullptr);
263 for (size_t i = 0; i < outer_size; i++) {
264 for (size_t j = 0; j < bias_size; j++) {
265 for (size_t k = 0; k < inner_size; k++) {
266 data[i * bias_size * inner_size + j * inner_size + k] += bias_data[j];
267 }
268 }
269 }
270 }
271 if (lite::InitParameterFromTensorInfo(param, tensor_info) != lite::RET_OK) {
272 MS_LOG(ERROR) << "Init parameter from tensor info failed.";
273 return nullptr;
274 }
275 param->set_name(name);
276 return param;
277 }
278
// Fuse Scale(Scale(x, w1, [b1]), w2, [b2]) into a single Scale node.
// `node` is the lower (outer) Scale matched by DefinePattern. The graph is
// rewritten in place through the manager: the lower Scale is re-wired to
// consume x directly with the fused weight/bias, and the upper Scale becomes
// dead. Always returns nullptr (no replacement node is produced).
const AnfNodePtr ScaleScaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                           const EquivPtr &) const {
  MS_ASSERT(func_graph != nullptr && node != nullptr);
  auto down_scale_cnode = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(down_scale_cnode != nullptr, nullptr);
  auto up_scale_node = down_scale_cnode->input(SECOND_INPUT);
  MS_CHECK_TRUE_RET(up_scale_node != nullptr, nullptr);
  auto up_scale_cnode = up_scale_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(up_scale_cnode != nullptr, nullptr);
  // The upper Scale's output must feed only the lower Scale, otherwise other
  // consumers would lose their input after fusion.
  if (IsMultiOutputTensors(func_graph, up_scale_cnode)) {
    return nullptr;
  }
  if (!CheckScaleNode(up_scale_cnode) || !CheckScaleNode(down_scale_cnode)) {
    return nullptr;
  }
  // An activation on the upper Scale cannot be folded into a linear rescale.
  auto scale_prim = ops::GetOperator<ops::ScaleFusion>(up_scale_cnode->input(FIRST_INPUT));
  MS_CHECK_TRUE_RET(scale_prim != nullptr, nullptr);
  auto scale_prim_c = scale_prim->GetPrim();
  MS_CHECK_TRUE_RET(scale_prim_c != nullptr, nullptr);
  if (scale_prim_c->GetAttr(ops::kActivationType) != nullptr && scale_prim->get_activation_type() != NO_ACTIVATION) {
    return nullptr;
  }

  if (GetInputParamsAndTensors(up_scale_cnode, down_scale_cnode) != lite::RET_OK) {
    MS_LOG(ERROR) << "Get inputs failed.";
    return nullptr;
  }
  auto new_weight_param = GenerateNewWeightNode(func_graph, down_scale_cnode->fullname_with_scope() + "_weight");
  if (new_weight_param == nullptr) {
    MS_LOG(ERROR) << "Generate new weight parameter node failed.";
    return nullptr;
  }
  // The fused scale starts at the smaller of the two axes.
  auto down_scale_prim = ops::GetOperator<ops::ScaleFusion>(down_scale_cnode->input(FIRST_INPUT));
  MS_CHECK_TRUE_RET(down_scale_prim != nullptr, nullptr);
  auto down_scale_prim_c = down_scale_prim->GetPrim();
  MS_CHECK_TRUE_RET(down_scale_prim_c != nullptr && down_scale_prim_c->GetAttr(ops::kAxis) != nullptr, nullptr);
  down_scale_prim->set_axis(MSMIN(up_scale_axis_, down_scale_axis_));

  auto manager = func_graph->manager();
  MS_ASSERT(manager != nullptr);
  // Bypass the upper Scale and install the fused weight.
  manager->SetEdge(down_scale_cnode, 1, up_scale_cnode->input(SECOND_INPUT));
  manager->SetEdge(down_scale_cnode, kInputIndexTwo, new_weight_param);
  // A fused bias is only needed when the upper Scale had one; it already
  // folds in the lower bias (if any).
  if (up_scale_cnode->size() == kScaleWithBiasLen) {
    ParameterPtr new_bias_param = GenerateNewBiasNode(func_graph, down_scale_cnode->fullname_with_scope() + "_bias");
    if (new_bias_param == nullptr) {
      // Fix: this failure concerns the bias node, not the weight node.
      MS_LOG(ERROR) << "Generate new bias parameter node failed.";
      return nullptr;
    }
    if (down_scale_cnode->size() == kScaleWithBiasLen) {
      manager->SetEdge(down_scale_cnode, kInputIndexThree, new_bias_param);
    } else {
      manager->AddEdge(down_scale_cnode, new_bias_param);
    }
  }

  return nullptr;
}
336 } // namespace mindspore::opt
337