• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "extendrt/delegate/ascend_ge/update_weight.h"
#include <algorithm>
#include <cstdint>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ops/auto_generate/gen_lite_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "tools/common/string_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/core/ir/manager.h"
#include "tools/common/tensor_util.h"
#include "mindspore/core/ops/conv_pool_ops.h"
26 namespace mindspore {
namespace {
// Divisor converting a microsecond duration into milliseconds for timing logs.
constexpr float kNumMicrosecondToMillisecond = 1000.0;
// Minimum number of CNode inputs (primitive + 2 operands) needed to read input index 2.
constexpr size_t kInputSize3 = 3;
// Rank of a MatMul-style constant weight that can be updated.
constexpr size_t kConstantWeightShapeSize = 2;
// Rank of a Conv-style constant weight that can be updated.
constexpr size_t kConstantConvWeightShapeSize = 4;
// CNode input index holding the weight operand.
constexpr size_t kInputIndex2 = 2;
// Suffix appended to a CNode's full name to name its zero-initialized "add" parameter.
constexpr const char *kUpdateWeightTensorNameSuffix = "_add_param";
// Suffix appended to a CNode's full name to name the inserted Add CNode.
constexpr const char *kUpdateWeightAddNodeNameSuffix = "_add_cnode";
// Length of kUpdateWeightTensorNameSuffix. Derived from the literal (constexpr since C++17)
// so the two can never drift apart, instead of hard-coding 10.
constexpr std::size_t kUpdateWeightTensorNameSuffixSize =
  std::char_traits<char>::length(kUpdateWeightTensorNameSuffix);
// Rank and per-axis indices of a 4-D convolution weight shape.
constexpr size_t kConvWeightSize = 4;
constexpr size_t kConvWeightShape0 = 0;
constexpr size_t kConvWeightShape1 = 1;
constexpr size_t kConvWeightShape2 = 2;
constexpr size_t kConvWeightShape3 = 3;
}  // namespace
42 
IsMatchName(const std::string & cnode_name,const std::string & param_name)43 bool UpdateWeight::IsMatchName(const std::string &cnode_name, const std::string &param_name) {
44   if (find(constant_cnode_name_.begin(), constant_cnode_name_.end(), cnode_name) != constant_cnode_name_.end()) {
45     MS_LOG(DEBUG) << "cnode name: " << cnode_name << ", param name: " << param_name;
46     return true;
47   }
48   return false;
49 }
50 
ParseUpdateWeightConfig(const std::string & names_str)51 bool UpdateWeight::ParseUpdateWeightConfig(const std::string &names_str) {
52   MS_LOG(DEBUG) << "names str: " << names_str;
53   constant_cnode_name_ = mindspore::lite::SplitStringToVector(names_str, ',');
54   if (constant_cnode_name_.empty()) {
55     MS_LOG(ERROR) << "split name is empty, name str is: " << names_str;
56     return false;
57   }
58   return true;
59 }
60 
GetVariableParamsName(const FuncGraphPtr & anf_graph)61 std::vector<std::string> UpdateWeight::GetVariableParamsName(const FuncGraphPtr &anf_graph) {
62   return new_weight_param_name_;
63 }
64 
SetInitDataNames(const std::vector<std::string> & init_data_names)65 bool UpdateWeight::SetInitDataNames(const std::vector<std::string> &init_data_names) {
66   if (init_data_names.empty()) {
67     MS_LOG(ERROR) << "init_data_names is empty.";
68     return false;
69   }
70   init_data_names_ = init_data_names;
71   return true;
72 }
73 
UpdateConstantTensorData(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> & weights,std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> * new_weights)74 bool UpdateWeight::UpdateConstantTensorData(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> &weights,
75                                             std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> *new_weights) {
76   // sort by init data name.
77   if (new_weights == nullptr) {
78     MS_LOG(ERROR) << "new_weight_tensors is nullptr.";
79     return false;
80   }
81   auto time1 = lite::GetTimeUs();
82   for (auto &weight : weights) {
83     std::vector<std::shared_ptr<tensor::Tensor>> new_weight_tensors;
84     std::map<std::string, std::shared_ptr<tensor::Tensor>> weights_pairs;
85     for (auto tensor : weight) {
86       MS_CHECK_TRUE_RET(tensor != nullptr, false);
87       weights_pairs[tensor->name()] = tensor;
88     }
89     for (auto &init_data_name : init_data_names_) {
90       auto size = init_data_name.size();
91       if (size < kUpdateWeightTensorNameSuffixSize) {
92         MS_LOG(ERROR) << "can not find init data name: " << init_data_name;
93         return false;
94       }
95       size_t last_slash_pos = init_data_name.find_last_of('/');
96       auto name = init_data_name.substr(0, last_slash_pos);
97       if (weights_pairs.find(name) == weights_pairs.end()) {
98         MS_LOG(ERROR) << "can not find init data name in user update weight tensors.";
99         return false;
100       }
101       auto weight_tensor = weights_pairs[name];
102       weight_tensor->set_name(init_data_name);
103       new_weight_tensors.push_back(weight_tensor);
104     }
105     new_weights->push_back(new_weight_tensors);
106   }
107   auto time2 = lite::GetTimeUs();
108   MS_LOG(INFO) << "Calculate update tensor time: " << (time2 - time1) / kNumMicrosecondToMillisecond << " ms";
109   return true;
110 }
111 
BuildFloatVec4DParameterNode(const FuncGraphPtr & anf_graph,ShapeVector weight_shape,const std::string & node_name)112 ParameterPtr UpdateWeight::BuildFloatVec4DParameterNode(const FuncGraphPtr &anf_graph, ShapeVector weight_shape,
113                                                         const std::string &node_name) {
114   if (weight_shape.size() != kConvWeightSize) {
115     MS_LOG(ERROR) << "weight_shape size is not 4, weight_shape size:" << weight_shape.size() << "!";
116     return nullptr;
117   }
118   MS_CHECK_TRUE_RET(anf_graph != nullptr, nullptr);
119   auto param_node = anf_graph->add_parameter();
120   MS_CHECK_TRUE_RET(param_node != nullptr, nullptr);
121   param_node->set_name(node_name);
122   auto weight_length = weight_shape[kConvWeightShape0] * weight_shape[kConvWeightShape1] *
123                        weight_shape[kConvWeightShape2] * weight_shape[kConvWeightShape3];
124   std::vector<float> data_1d(weight_length, 0);
125   auto size = data_1d.size() * sizeof(float);
126   std::vector<int64_t> shape_vector = {
127     static_cast<int64_t>(weight_shape[kConvWeightShape0]), static_cast<int64_t>(weight_shape[kConvWeightShape1]),
128     static_cast<int64_t>(weight_shape[kConvWeightShape2]), static_cast<int64_t>(weight_shape[kConvWeightShape3])};
129   auto tensor_info = lite::CreateTensorInfo(data_1d.data(), size, shape_vector, kNumberTypeFloat32);
130   if (tensor_info == nullptr) {
131     MS_LOG(ERROR) << "Create tensor info failed!";
132     return nullptr;
133   }
134   auto status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
135   if (status != RET_OK) {
136     MS_LOG(ERROR) << "init parameter from tensor info failed!";
137     return nullptr;
138   }
139   return param_node;
140 }
141 
JudgeNodeType(const AnfNodePtr & node)142 bool JudgeNodeType(const AnfNodePtr &node) {
143   return !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2D) &&
144          !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulV2) &&
145          !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMul) &&
146          !mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimBatchMatMul);
147 }
148 
CreateAddOpNodeForGraph(const FuncGraphPtr & anf_graph)149 bool UpdateWeight::CreateAddOpNodeForGraph(const FuncGraphPtr &anf_graph) {
150   MS_CHECK_TRUE_RET(anf_graph != nullptr, false);
151   if (constant_cnode_name_.empty()) {
152     MS_LOG(ERROR) << "constant_cnode_name_ is empty, user not set config file for update weight!";
153     return false;
154   }
155   auto node_list = TopoSort(anf_graph->get_return());
156   for (auto &node : node_list) {
157     MS_CHECK_TRUE_RET(node != nullptr, false);
158     if (!utils::isa<CNodePtr>(node)) {
159       continue;
160     }
161     auto cnode = utils::cast<CNodePtr>(node);
162     MS_CHECK_TRUE_RET(cnode != nullptr, false);
163     size_t last_slash_pos = cnode->fullname_with_scope().find_last_of('/');
164     string search_key = "";
165     if (last_slash_pos != std::string::npos) {
166       search_key = cnode->fullname_with_scope().substr(0, last_slash_pos);
167     } else {
168       MS_LOG(INFO) << "Find last slash failed! Cnode name:" << cnode->fullname_with_scope() << "!";
169     }
170     if (find(constant_cnode_name_.begin(), constant_cnode_name_.end(), search_key) == constant_cnode_name_.end()) {
171       continue;
172     } else if (JudgeNodeType(node)) {
173       continue;
174     }
175     if (cnode->size() < kInputSize3) {
176       MS_LOG(ERROR) << "cnode input size less " << kInputSize3;
177       return false;
178     }
179     auto weight = cnode->input(kInputIndex2);
180     MS_CHECK_TRUE_RET(weight != nullptr, false);
181 
182     // create Add node
183     auto add_prim = std::make_shared<ops::Add>();
184     if (add_prim == nullptr) {
185       MS_LOG(ERROR) << "create add prim failed.";
186       return false;
187     }
188     auto add_prim_c = add_prim->GetPrim();
189     MS_CHECK_TRUE_RET(add_prim_c != nullptr, false);
190     if (!utils::isa<ParameterPtr>(weight)) {
191       MS_LOG(ERROR) << "matmul weight is not constant, can not update weight.";
192       return false;
193     }
194     auto weight_param = weight->cast<ParameterPtr>();
195     MS_CHECK_TRUE_RET(weight_param != nullptr, false);
196     auto value = weight_param->default_param();
197     MS_CHECK_TRUE_RET(value != nullptr, false);
198     auto weight_tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
199     MS_CHECK_TRUE_RET(weight_tensor != nullptr, false);
200     auto weight_shape = weight_tensor->shape();
201     AnfNodePtr add_param_node = nullptr;
202     if (weight_shape.size() == kConstantWeightShapeSize) {
203       std::vector<std::vector<float>> add_param_data(weight_shape[0], std::vector<float>(weight_shape[1], 0));
204       add_param_node = opt::BuildFloatVec2DParameterNode(anf_graph, add_param_data,
205                                                          cnode->fullname_with_scope() + kUpdateWeightTensorNameSuffix);
206       if (add_param_node == nullptr) {
207         MS_LOG(ERROR) << "create param node failed!";
208         return false;
209       }
210     } else if (weight_shape.size() == kConstantConvWeightShapeSize) {
211       add_param_node = BuildFloatVec4DParameterNode(anf_graph, weight_shape,
212                                                     cnode->fullname_with_scope() + kUpdateWeightTensorNameSuffix);
213       if (add_param_node == nullptr) {
214         MS_LOG(ERROR) << "create param node failed!";
215         return false;
216       }
217     } else {
218       MS_LOG(ERROR) << "now only support 2 dims matmul and 4 dims conv constant weight!"
219                     << "weight_shape:" << weight_shape.size() << "node name:" << cnode->fullname_with_scope() << "!";
220       return false;
221     }
222 
223     if (add_param_node == nullptr) {
224       MS_LOG(ERROR) << "create param node failed!";
225       return false;
226     }
227     new_weight_param_name_.push_back(cnode->fullname_with_scope() + "_add_param");
228     auto inputs = {weight, add_param_node};
229     auto add_cnode = anf_graph->NewCNode(add_prim_c, inputs);
230     if (add_cnode == nullptr) {
231       MS_LOG(ERROR) << "new add node failed.";
232       return false;
233     }
234     add_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + kUpdateWeightAddNodeNameSuffix);
235     if (node->abstract() != nullptr) {
236       add_cnode->set_abstract(node->abstract()->Clone());
237     }
238     auto manager = Manage(anf_graph);
239     (void)manager->Replace(weight, add_cnode);
240   }
241   if (new_weight_param_name_.size() != constant_cnode_name_.size()) {
242     MS_LOG(ERROR) << "init data name size is not equal user config file name size, new_weight_param_name_: "
243                   << new_weight_param_name_.size() << ", constant_cnode_name_ size: " << constant_cnode_name_.size();
244   }
245   MS_LOG(INFO) << "new_weight_param_name_ size: " << new_weight_param_name_.size()
246                << ", constant_cnode_name_ size: " << constant_cnode_name_.size();
247   return true;
248 }
249 }  // namespace mindspore
250