1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "extendrt/delegate/ascend_ge/update_weight.h"
18 #include <string>
19 #include "ops/auto_generate/gen_lite_ops.h"
20 #include "mindspore/core/ops/math_ops.h"
21 #include "tools/common/string_util.h"
22 #include "tools/optimizer/common/gllo_utils.h"
23 #include "mindspore/core/ir/manager.h"
24 #include "tools/common/tensor_util.h"
25 #include "mindspore/core/ops/conv_pool_ops.h"
26 namespace mindspore {
namespace {
// Divisor converting microseconds (from lite::GetTimeUs) to milliseconds for log output.
constexpr float kNumMicrosecondToMillisecond = 1000.0;
// Minimum cnode input count required before reading the weight input (primitive + data + weight).
constexpr size_t kInputSize3 = 3;
// Rank of a constant matmul weight (2-D).
constexpr size_t kConstantWeightShapeSize = 2;
// Rank of a constant convolution weight (4-D).
constexpr size_t kConstantConvWeightShapeSize = 4;
// Index of the weight input within a cnode's input list.
constexpr size_t kInputIndex2 = 2;
// Suffixes appended to the injected delta parameter / Add cnode names.
constexpr const char *kUpdateWeightTensorNameSuffix = "_add_param";
constexpr const char *kUpdateWeightAddNodeNameSuffix = "_add_cnode";
// Length of kUpdateWeightTensorNameSuffix ("_add_param"); keep the two in sync.
constexpr std::size_t kUpdateWeightTensorNameSuffixSize = 10;
// Expected rank of a conv weight and the indices of its four dimensions.
constexpr size_t kConvWeightSize = 4;
constexpr size_t kConvWeightShape0 = 0;
constexpr size_t kConvWeightShape1 = 1;
constexpr size_t kConvWeightShape2 = 2;
constexpr size_t kConvWeightShape3 = 3;
}  // namespace
42
IsMatchName(const std::string & cnode_name,const std::string & param_name)43 bool UpdateWeight::IsMatchName(const std::string &cnode_name, const std::string ¶m_name) {
44 if (find(constant_cnode_name_.begin(), constant_cnode_name_.end(), cnode_name) != constant_cnode_name_.end()) {
45 MS_LOG(DEBUG) << "cnode name: " << cnode_name << ", param name: " << param_name;
46 return true;
47 }
48 return false;
49 }
50
ParseUpdateWeightConfig(const std::string & names_str)51 bool UpdateWeight::ParseUpdateWeightConfig(const std::string &names_str) {
52 MS_LOG(DEBUG) << "names str: " << names_str;
53 constant_cnode_name_ = mindspore::lite::SplitStringToVector(names_str, ',');
54 if (constant_cnode_name_.empty()) {
55 MS_LOG(ERROR) << "split name is empty, name str is: " << names_str;
56 return false;
57 }
58 return true;
59 }
60
// Returns the names of the delta parameters collected by CreateAddOpNodeForGraph.
// NOTE(review): anf_graph is currently unused — presumably kept for interface
// compatibility; confirm before removing.
std::vector<std::string> UpdateWeight::GetVariableParamsName(const FuncGraphPtr &anf_graph) {
  return new_weight_param_name_;
}
64
SetInitDataNames(const std::vector<std::string> & init_data_names)65 bool UpdateWeight::SetInitDataNames(const std::vector<std::string> &init_data_names) {
66 if (init_data_names.empty()) {
67 MS_LOG(ERROR) << "init_data_names is empty.";
68 return false;
69 }
70 init_data_names_ = init_data_names;
71 return true;
72 }
73
UpdateConstantTensorData(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> & weights,std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> * new_weights)74 bool UpdateWeight::UpdateConstantTensorData(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> &weights,
75 std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> *new_weights) {
76 // sort by init data name.
77 if (new_weights == nullptr) {
78 MS_LOG(ERROR) << "new_weight_tensors is nullptr.";
79 return false;
80 }
81 auto time1 = lite::GetTimeUs();
82 for (auto &weight : weights) {
83 std::vector<std::shared_ptr<tensor::Tensor>> new_weight_tensors;
84 std::map<std::string, std::shared_ptr<tensor::Tensor>> weights_pairs;
85 for (auto tensor : weight) {
86 MS_CHECK_TRUE_RET(tensor != nullptr, false);
87 weights_pairs[tensor->name()] = tensor;
88 }
89 for (auto &init_data_name : init_data_names_) {
90 auto size = init_data_name.size();
91 if (size < kUpdateWeightTensorNameSuffixSize) {
92 MS_LOG(ERROR) << "can not find init data name: " << init_data_name;
93 return false;
94 }
95 size_t last_slash_pos = init_data_name.find_last_of('/');
96 auto name = init_data_name.substr(0, last_slash_pos);
97 if (weights_pairs.find(name) == weights_pairs.end()) {
98 MS_LOG(ERROR) << "can not find init data name in user update weight tensors.";
99 return false;
100 }
101 auto weight_tensor = weights_pairs[name];
102 weight_tensor->set_name(init_data_name);
103 new_weight_tensors.push_back(weight_tensor);
104 }
105 new_weights->push_back(new_weight_tensors);
106 }
107 auto time2 = lite::GetTimeUs();
108 MS_LOG(INFO) << "Calculate update tensor time: " << (time2 - time1) / kNumMicrosecondToMillisecond << " ms";
109 return true;
110 }
111
BuildFloatVec4DParameterNode(const FuncGraphPtr & anf_graph,ShapeVector weight_shape,const std::string & node_name)112 ParameterPtr UpdateWeight::BuildFloatVec4DParameterNode(const FuncGraphPtr &anf_graph, ShapeVector weight_shape,
113 const std::string &node_name) {
114 if (weight_shape.size() != kConvWeightSize) {
115 MS_LOG(ERROR) << "weight_shape size is not 4, weight_shape size:" << weight_shape.size() << "!";
116 return nullptr;
117 }
118 MS_CHECK_TRUE_RET(anf_graph != nullptr, nullptr);
119 auto param_node = anf_graph->add_parameter();
120 MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
121 param_node->set_name(node_name);
122 auto weight_length = weight_shape[kConvWeightShape0] * weight_shape[kConvWeightShape1] *
123 weight_shape[kConvWeightShape2] * weight_shape[kConvWeightShape3];
124 std::vector<float> data_1d(weight_length, 0);
125 auto size = data_1d.size() * sizeof(float);
126 std::vector<int64_t> shape_vector = {
127 static_cast<int64_t>(weight_shape[kConvWeightShape0]), static_cast<int64_t>(weight_shape[kConvWeightShape1]),
128 static_cast<int64_t>(weight_shape[kConvWeightShape2]), static_cast<int64_t>(weight_shape[kConvWeightShape3])};
129 auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeFloat32);
130 if (tensor_info == nullptr) {
131 MS_LOG(ERROR) << "Create tensor info failed!";
132 return nullptr;
133 }
134 auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
135 if (status != RET_OK) {
136 MS_LOG(ERROR) << "init parameter from tensor info failed!";
137 return nullptr;
138 }
139 return param_node;
140 }
141
JudgeNodeType(const AnfNodePtr & node)142 bool JudgeNodeType(const AnfNodePtr &node) {
143 return !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2D) &&
144 !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulV2) &&
145 !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMul) &&
146 !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimBatchMatMul);
147 }
148
CreateAddOpNodeForGraph(const FuncGraphPtr & anf_graph)149 bool UpdateWeight::CreateAddOpNodeForGraph(const FuncGraphPtr &anf_graph) {
150 MS_CHECK_TRUE_RET(anf_graph != nullptr, false);
151 if (constant_cnode_name_.empty()) {
152 MS_LOG(ERROR) << "constant_cnode_name_ is empty, user not set config file for update weight!";
153 return false;
154 }
155 auto node_list = TopoSort(anf_graph->get_return());
156 for (auto &node : node_list) {
157 MS_CHECK_TRUE_RET(node != nullptr, false);
158 if (!utils::isa<CNodePtr>(node)) {
159 continue;
160 }
161 auto cnode = utils::cast<CNodePtr>(node);
162 MS_CHECK_TRUE_RET(cnode != nullptr, false);
163 size_t last_slash_pos = cnode->fullname_with_scope().find_last_of('/');
164 string search_key = "";
165 if (last_slash_pos != std::string::npos) {
166 search_key = cnode->fullname_with_scope().substr(0, last_slash_pos);
167 } else {
168 MS_LOG(INFO) << "Find last slash failed! Cnode name:" << cnode->fullname_with_scope() << "!";
169 }
170 if (find(constant_cnode_name_.begin(), constant_cnode_name_.end(), search_key) == constant_cnode_name_.end()) {
171 continue;
172 } else if (JudgeNodeType(node)) {
173 continue;
174 }
175 if (cnode->size() < kInputSize3) {
176 MS_LOG(ERROR) << "cnode input size less " << kInputSize3;
177 return false;
178 }
179 auto weight = cnode->input(kInputIndex2);
180 MS_CHECK_TRUE_RET(weight != nullptr, false);
181
182 // create Add node
183 auto add_prim = std::make_shared<ops::Add>();
184 if (add_prim == nullptr) {
185 MS_LOG(ERROR) << "create add prim failed.";
186 return false;
187 }
188 auto add_prim_c = add_prim->GetPrim();
189 MS_CHECK_TRUE_RET(add_prim_c != nullptr, false);
190 if (!utils::isa<ParameterPtr>(weight)) {
191 MS_LOG(ERROR) << "matmul weight is not constant, can not update weight.";
192 return false;
193 }
194 auto weight_param = weight->cast<ParameterPtr>();
195 MS_CHECK_TRUE_RET(weight_param != nullptr, false);
196 auto value = weight_param->default_param();
197 MS_CHECK_TRUE_RET(value != nullptr, false);
198 auto weight_tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
199 MS_CHECK_TRUE_RET(weight_tensor != nullptr, false);
200 auto weight_shape = weight_tensor->shape();
201 AnfNodePtr add_param_node = nullptr;
202 if (weight_shape.size() == kConstantWeightShapeSize) {
203 std::vector<std::vector<float>> add_param_data(weight_shape[0], std::vector<float>(weight_shape[1], 0));
204 add_param_node = opt::BuildFloatVec2DParameterNode(anf_graph, add_param_data,
205 cnode->fullname_with_scope() + kUpdateWeightTensorNameSuffix);
206 if (add_param_node == nullptr) {
207 MS_LOG(ERROR) << "create param node failed!";
208 return false;
209 }
210 } else if (weight_shape.size() == kConstantConvWeightShapeSize) {
211 add_param_node = BuildFloatVec4DParameterNode(anf_graph, weight_shape,
212 cnode->fullname_with_scope() + kUpdateWeightTensorNameSuffix);
213 if (add_param_node == nullptr) {
214 MS_LOG(ERROR) << "create param node failed!";
215 return false;
216 }
217 } else {
218 MS_LOG(ERROR) << "now only support 2 dims matmul and 4 dims conv constant weight!"
219 << "weight_shape:" << weight_shape.size() << "node name:" << cnode->fullname_with_scope() << "!";
220 return false;
221 }
222
223 if (add_param_node == nullptr) {
224 MS_LOG(ERROR) << "create param node failed!";
225 return false;
226 }
227 new_weight_param_name_.push_back(cnode->fullname_with_scope() + "_add_param");
228 auto inputs = {weight, add_param_node};
229 auto add_cnode = anf_graph->NewCNode(add_prim_c, inputs);
230 if (add_cnode == nullptr) {
231 MS_LOG(ERROR) << "new add node failed.";
232 return false;
233 }
234 add_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + kUpdateWeightAddNodeNameSuffix);
235 if (node->abstract() != nullptr) {
236 add_cnode->set_abstract(node->abstract()->Clone());
237 }
238 auto manager = Manage(anf_graph);
239 (void)manager->Replace(weight, add_cnode);
240 }
241 if (new_weight_param_name_.size() != constant_cnode_name_.size()) {
242 MS_LOG(ERROR) << "init data name size is not equal user config file name size, new_weight_param_name_: "
243 << new_weight_param_name_.size() << ", constant_cnode_name_ size: " << constant_cnode_name_.size();
244 }
245 MS_LOG(INFO) << "new_weight_param_name_ size: " << new_weight_param_name_.size()
246 << ", constant_cnode_name_ size: " << constant_cnode_name_.size();
247 return true;
248 }
249 } // namespace mindspore
250