/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "mindspore/lite/tools/common/custom_ascend_utils.h"
#include "mindspore/lite/src/common/log_util.h"
#include "mindspore/core/utils/ms_utils_secure.h"
#include "mindspore/lite/tools/common/func_graph_utils.h"
#include "mindspore/core/ops/tuple_get_item.h"
#include "mindspore/lite/src/common/common.h"
#include "mindspore/lite/tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace {
constexpr auto kCustomPrimTypeACL = "ACL";
constexpr auto kCustomNodeName = "custom_0";
constexpr auto kFuncType = "func_type";
constexpr auto kUniqueName = "uniq_name";

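// Serialize the dynamic KV-cache flags and layout into the "dynamic_kv_cache" attribute as a flat
// key/value string list; skipped when neither batch size nor sequence length is dynamic.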
void SaveDynKVCacheInfo(const DynKVCacheSaveInfo &dyn_kv_info, std::map<std::string, ValuePtr> *attr_map) {
  if (!dyn_kv_info.batch_size_dyn && !dyn_kv_info.seq_length_dyn) {
    return;
  }
  std::vector<std::string> dynamic_kv_cache;
  dynamic_kv_cache.push_back("batch_size_dyn");
  dynamic_kv_cache.push_back(std::to_string(dyn_kv_info.batch_size_dyn));
  dynamic_kv_cache.push_back("seq_length_dyn");
  dynamic_kv_cache.push_back(std::to_string(dyn_kv_info.seq_length_dyn));
  dynamic_kv_cache.push_back("kv_cache_layout");
  dynamic_kv_cache.push_back(dyn_kv_info.kv_cache_layout);
  (*attr_map)["dynamic_kv_cache"] = MakeValue(dynamic_kv_cache);
}

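// Restore the dynamic KV-cache settings from the "dynamic_kv_cache" attribute written by
// SaveDynKVCacheInfo, parsing the flat list as key/value pairs.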
void LoadDynKVCacheInfo(const std::map<std::string, ValuePtr> &attr_map, DynKVCacheSaveInfo *dyn_kv_info) {
  if (dyn_kv_info == nullptr) {
    return;
  }
  auto it = attr_map.find("dynamic_kv_cache");
  if (it == attr_map.end()) {
    return;
  }
  auto option_pairs = GetValue<std::vector<std::string>>(it->second);

  constexpr size_t pair_size = 2;
  if (option_pairs.size() % pair_size != 0) {
    MS_LOG(WARNING) << "Attr dynamic_kv_cache value sequence size " << option_pairs.size()
                    << " is invalid, option pairs: " << option_pairs;
  }
  for (size_t i = 0; i + 1 < option_pairs.size(); i += pair_size) {
    auto &key = option_pairs[i];
    auto &val = option_pairs[i + 1];
    MS_LOG(INFO) << "Set dynamic_kv_cache option " << key << ": " << val;
    if (key == "batch_size_dyn") {
      dyn_kv_info->batch_size_dyn = std::stoi(val);
    } else if (key == "seq_length_dyn") {
      dyn_kv_info->seq_length_dyn = std::stoi(val);
    } else if (key == "kv_cache_layout") {
      dyn_kv_info->kv_cache_layout = val;
    }
  }
}
}  // namespace

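// Wrap the serialized OM model buffer into a uint8 Parameter so it can be carried inside the
// FuncGraph as a constant input of the custom node.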
ParameterPtr CustomAscendUtils::CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om_data,
                                                  const std::string &graph_name) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
  ParameterPtr om_parameter = func_graph->add_parameter();
  MS_CHECK_TRUE_MSG(om_parameter != nullptr, nullptr, "om_parameter is nullptr.");
  om_parameter->set_name(graph_name);
  om_parameter->debug_info()->set_name(graph_name);

  auto type_ptr = TypeIdToType(kNumberTypeUInt8);
  MS_CHECK_TRUE_MSG(type_ptr != nullptr, nullptr, "type_ptr is nullptr.");
  ShapeVector shape_vector = {static_cast<int64_t>(om_data.DataSize())};
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
  MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "abstract_tensor is nullptr.");
  om_parameter->set_abstract(abstract_tensor);

  auto param_value =
    std::make_shared<tensor::Tensor>(kNumberTypeUInt8, ShapeVector({static_cast<int64_t>(om_data.DataSize())}));
  MS_CHECK_TRUE_MSG(param_value != nullptr, nullptr, "param_value is nullptr.");
  auto tensor_data = param_value->data_c();
  MS_CHECK_TRUE_MSG(tensor_data != nullptr, nullptr, "New Tensor failed.");
  if (param_value->Size() < om_data.DataSize()) {
    MS_LOG(ERROR) << "Dst buff size " << param_value->Size() << " should be no less than src buff size "
                  << om_data.DataSize();
    return nullptr;
  }
  if (common::huge_memcpy(reinterpret_cast<uint8_t *>(tensor_data), param_value->Size(),
                          reinterpret_cast<const uint8_t *>(om_data.Data()), om_data.DataSize()) != EOK) {
    MS_LOG(ERROR) << "Memcpy om data failed.";
    return nullptr;
  }
  om_parameter->set_default_param(param_value);
  return om_parameter;
}

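// Copy the abstracts of the original graph outputs onto the custom node: a single tensor abstract
// when there is one output, or an AbstractTuple when the graph has multiple outputs.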
bool CustomAscendUtils::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
  if (outputs_.size() == 1) {
    auto abstract_tensor = FuncGraphUtils::GetAbstractFromNode(outputs_[0]);
    if (abstract_tensor == nullptr) {
      MS_LOG(ERROR) << "Abstract_tensor is nullptr.";
      return false;
    }
    auto abstract_tensor_clone = abstract_tensor->Clone();
    abstract_tensor_clone->set_name(abstract_tensor->name());
    custom_node->set_abstract(abstract_tensor_clone);
    return true;
  } else {
    AbstractBasePtrList abstract_list;
    for (size_t j = 0; j < outputs_.size(); j++) {
      auto abstract_tensor = FuncGraphUtils::GetAbstractFromNode(outputs_[j]);
      if (abstract_tensor == nullptr) {
        MS_LOG(ERROR) << "Abstract tensor is nullptr for output " << j;
        return false;
      }
      auto abstract_tensor_clone = abstract_tensor->Clone();
      abstract_tensor_clone->set_name(abstract_tensor->name());
      (void)abstract_list.emplace_back(abstract_tensor_clone);
    }
    custom_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
  }
  return true;
}

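// Record all-zero RefData parameters on the primitive as (name, data type, shape) triples so their
// tensor contents do not have to be stored in the exported graph.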
void CustomAscendUtils::SetZeroValueRefDatas(const ops::PrimitiveCPtr &primc,
                                             const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_infos) {
  ValuePtrList value_ptr_list;
  for (const auto &item : ref_infos) {
    (void)value_ptr_list.emplace_back(MakeValue<std::string>(item.first));
    (void)value_ptr_list.emplace_back(MakeValue(static_cast<uint64_t>(item.second->data_type())));
    (void)value_ptr_list.emplace_back(MakeValue(item.second->shape_c()));
  }
  (void)primc->AddAttr(lite::kNameAttrZeroValRefDatas, MakeValue(value_ptr_list));
}

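// Rebuild the (name, tensor) list of all-zero RefData parameters from the attribute written by
// SetZeroValueRefDatas; returns true when the attribute is absent or parsed successfully.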
bool CustomAscendUtils::GetZeroValueRefDatas(const ops::PrimitiveCPtr &primc,
                                             std::vector<std::pair<std::string, tensor::TensorPtr>> *ref_infos) {
  auto attr = primc->GetAttr(lite::kNameAttrZeroValRefDatas);
  if (attr == nullptr) {
    MS_LOG(INFO) << "Not found attr " << lite::kNameAttrZeroValRefDatas << " in custom node";
    return true;
  }
  auto value_ptr_list = GetValue<ValuePtrList>(attr);
  constexpr size_t every_item_size = 3;
  if (value_ptr_list.size() % every_item_size != 0) {
    MS_LOG(ERROR) << "Attr " << lite::kNameAttrZeroValRefDatas << " item count should be a multiple of 3, but got "
                  << value_ptr_list.size();
    return false;
  }
  for (size_t i = 0; i < value_ptr_list.size(); i += every_item_size) {
    auto param_name = GetValue<std::string>(value_ptr_list[i]);
    auto data_type = static_cast<TypeId>(GetValue<uint64_t>(value_ptr_list[i + 1]));
    auto param_shape = GetValue<ShapeVector>(value_ptr_list[i + 2]);
    auto tensor = std::make_shared<tensor::Tensor>(data_type, param_shape);
    ref_infos->push_back(std::make_pair(param_name, tensor));
  }
  return true;
}

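// Build the Custom(ACL) cnode: graph inputs plus non-zero RefData parameters become its inputs,
// the OM parameter is appended last, and output abstracts, attributes and RefData info are attached.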
CNodePtr CustomAscendUtils::CreateCustomNode(const FuncGraphPtr &func_graph, const ParameterPtr &om_parameter,
                                             const std::map<std::string, ValuePtr> &attr_map,
                                             const std::vector<std::string> &ref_datas) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
  auto prim = std::make_shared<mindspore::ops::Custom>();
  MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "New custom op failed.");
  prim->set_type(kCustomPrimTypeACL);
  auto prim_c = prim->GetPrim();
  auto graph_input = func_graph->get_inputs();
  std::vector<std::pair<std::string, tensor::TensorPtr>> zero_val_ref_infos;
  for (auto &item : func_graph->parameters()) {
    if (item && item->cast<ParameterPtr>() != nullptr) {
      auto parameter = item->cast<ParameterPtr>();
      auto param_name = parameter->name();
      if (std::find(ref_datas.begin(), ref_datas.end(), param_name) != ref_datas.end()) {
        auto value = parameter->default_param();
        if (value == nullptr) {
          continue;
        }
        auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
        if (tensor == nullptr) {
          continue;
        }
        if (IsParameterValueZero(tensor)) {
          zero_val_ref_infos.push_back(std::make_pair(param_name, tensor));
        } else {
          graph_input.push_back(item);
        }
      }
    }
  }
  CNodePtr custom_node = func_graph->NewCNode(prim_c, graph_input);
  MS_CHECK_TRUE_MSG(custom_node != nullptr, nullptr, "Custom cnode failed.");
  custom_node->set_fullname_with_scope(kCustomNodeName);
  custom_node->add_input(om_parameter);

  if (!SetCustomOutputs(func_graph, custom_node)) {
    MS_LOG(ERROR) << "Set custom outputs failed.";
    return nullptr;
  }
  SetCustomAttrs(prim, attr_map);
  SetZeroValueRefDatas(prim_c, zero_val_ref_infos);
  (void)prim->AddAttr(lite::kNameAttrRefDatas, api::MakeValue(ref_datas));
  return custom_node;
}

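// Attach the custom-node attributes: the encoded output shapes, the func_type/uniq_name markers,
// and any extra attributes passed in attr_map.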
void CustomAscendUtils::SetCustomAttrs(const std::shared_ptr<ops::Custom> &prim,
                                       const std::map<std::string, ValuePtr> &attr_map) {
  std::string output_dim_str;
  for (const auto &item : outputs_) {
    auto shape = opt::GetAnfNodeOutputShape(item.first, item.second);
    output_dim_str += std::to_string(shape.size()) + ",";
    for (const auto &val : shape) {
      output_dim_str += std::to_string(val) + ",";
    }
  }
  std::vector<uint8_t> output_dim_char(output_dim_str.begin(), output_dim_str.end());
  std::map<std::string, std::vector<uint8_t>> attrs = {{lite::kOutputShapes, output_dim_char}};
  prim->set_attr(attrs);
  prim->AddAttr(kFuncType, api::MakeValue<std::string>("acl_build"));
  prim->AddAttr(kUniqueName, api::MakeValue<std::string>(lite::kNameCustomAscend));
  auto prim_c = prim->GetPrim();
  for (auto &attr : attr_map) {
    prim_c->AddAttr(attr.first, attr.second);
  }
}

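// Create a MakeTuple of TupleGetItem nodes over the custom node so a multi-output graph keeps its
// original output arity.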
CNodePtr CustomAscendUtils::CreateMakeTupleGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
  std::vector<CNodePtr> node_list;
  for (size_t j = 0; j < outputs_.size(); ++j) {
    auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
    if (tuple_get_item_prim_ptr == nullptr) {
      MS_LOG(ERROR) << "New TupleGetItem failed for output " << j;
      return nullptr;
    }
    auto tuple_get_item_prim_ptr_c = tuple_get_item_prim_ptr->GetPrim();
    auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr_c);
    MS_CHECK_TRUE_MSG(tuple_get_item_prim != nullptr, nullptr, "item_prim is nullptr.");
    auto get_item_value = NewValueNode(MakeValue<int64_t>(j));
    MS_CHECK_TRUE_MSG(get_item_value != nullptr, nullptr, "item_value is nullptr.");
    AnfNodePtrList inputs{tuple_get_item_prim, custom_node, get_item_value};
    CNodePtr get_item_cnode = func_graph->NewCNode(inputs);
    if (get_item_cnode == nullptr) {
      MS_LOG(ERROR) << "New get item cnode failed for output " << j;
      return nullptr;
    }
    get_item_cnode->set_fullname_with_scope(custom_node->fullname_with_scope() + "_getitem_" + std::to_string(j));
    node_list.emplace_back(get_item_cnode);
  }
  auto make_tuple_val_node = NewValueNode(prim::kPrimMakeTuple);
  MS_CHECK_TRUE_MSG(make_tuple_val_node != nullptr, nullptr, "New make tuple val node failed.");
  AnfNodePtrList new_inputs = {make_tuple_val_node};
  new_inputs.insert(new_inputs.end(), node_list.begin(), node_list.end());
  auto make_tuple_cnode = func_graph->NewCNode(new_inputs);
  MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, nullptr, "New make tuple cnode failed.");
  return make_tuple_cnode;
}

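// Replace the original graph output with the custom node (or a MakeTuple over it) and drop
// parameters with default values that are no longer referenced.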
bool CustomAscendUtils::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
  auto manager = Manage(func_graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "Manager is nullptr";
    return false;
  }
  AnfNodePtr return_input = func_graph->output();
  MS_CHECK_TRUE_MSG(return_input != nullptr, false, "return input is nullptr.");
  if (outputs_.size() == 1) {
    if (!manager->Replace(return_input, custom_node)) {
      MS_LOG(ERROR) << "Replace node failed.";
      return false;
    }
  } else {
    auto make_tuple_node = CreateMakeTupleGraphOutput(func_graph, custom_node);
    MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, false, "Create make tuple cnode failed.");
    if (!manager->Replace(return_input, make_tuple_node)) {
      MS_LOG(ERROR) << "Replace node failed for outputs of graph.";
      return false;
    }
  }
  std::vector<AnfNodePtr> new_parameters;
  auto node_users = manager->node_users();
  for (auto &item : func_graph->parameters()) {
    auto parameter = item->cast<ParameterPtr>();
    if (!parameter) {
      continue;
    }
    if (!parameter->has_default()) {
      new_parameters.push_back(parameter);
    } else {
      auto users = node_users.find(item);
      if (users != node_users.end() && !users->second.empty()) {
        new_parameters.push_back(item);
      }
    }
  }
  manager->SetParameters(func_graph, new_parameters);
  MS_LOG(DEBUG) << "Modify graph by custom node success.";
  return true;
}

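// Check whether every byte of the tensor is zero, scanning in 64-bit chunks and then the remaining
// tail bytes.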
bool CustomAscendUtils::IsParameterValueZero(const tensor::TensorPtr &tensor) {
  if (tensor == nullptr) {
    return false;
  }
  auto size = tensor->Size();
  auto count = size / sizeof(uint64_t);
  auto data_u8 = reinterpret_cast<uint8_t *>(tensor->data_c());
  if (data_u8 == nullptr) {
    return false;
  }
  auto data_u64 = reinterpret_cast<uint64_t *>(tensor->data_c());
  for (size_t i = 0; i < count; i++) {
    if (data_u64[i] != 0) {
      return false;
    }
  }
  for (size_t i = count * sizeof(uint64_t); i < size; i++) {
    if (data_u8[i] != 0) {
      return false;
    }
  }
  return true;
}

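// Convert func_graph into a single-custom-node graph that carries the compiled OM model: create the
// OM parameter, build the custom node with the given attributes and RefData names, and rewire the
// graph outputs.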
bool CustomAscendUtils::CreateCustomFuncGraph(const FuncGraphPtr &func_graph, const Buffer &model_cache,
                                              const std::string &graph_name,
                                              const std::map<std::string, ValuePtr> &attr_map,
                                              const std::vector<std::string> &ref_datas,
                                              const DynKVCacheSaveInfo &dyn_kv_info) {
  CustomAscendUtils utils;
  utils.outputs_ = opt::GetNodeInputs(func_graph->get_return());
  auto om_parameter = CreateOmParameter(func_graph, model_cache, graph_name);
  if (om_parameter == nullptr) {
    MS_LOG(ERROR) << "Create custom parameter failed";
    return false;
  }
  std::map<std::string, ValuePtr> attr_map_new = attr_map;
  SaveDynKVCacheInfo(dyn_kv_info, &attr_map_new);
  auto cnode = utils.CreateCustomNode(func_graph, om_parameter, attr_map_new, ref_datas);
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "Create custom cnode failed";
    return false;
  }
  if (!utils.ModifyGraphByCustomNode(func_graph, cnode)) {
    return false;
  }
  return true;
}

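// Return the single CustomAscend cnode of the graph, or nullptr if the graph contains any other
// real kernel or more than one custom node.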
CNodePtr CustomAscendUtils::GetCustomNode(const FuncGraphPtr &func_graph) {
  if (func_graph == nullptr) {
    return nullptr;
  }
  auto nodes = func_graph->TopoSort(func_graph->get_return());
  if (nodes.empty()) {
    MS_LOG(WARNING) << "There are no nodes in the graph";
    return nullptr;
  }
  CNodePtr custom_node = nullptr;
  size_t cnode_count = 0;
  for (auto &node : nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!cnode || !AnfUtils::IsRealKernel(cnode)) {
      continue;
    }
    std::string kernel_name = AnfUtils::GetCNodeName(cnode);
    if (kernel_name != lite::kNameCustomAscend) {
      return nullptr;
    }
    cnode_count += 1;
    if (cnode_count > 1) {
      MS_LOG(ERROR) << "Only support one " << lite::kNameCustomAscend << " node, but got " << kernel_name << ", node "
                    << cnode->fullname_with_scope();
      return nullptr;
    }
    auto inputs = cnode->inputs();
    if (inputs.size() < 1) {
      MS_LOG(ERROR) << "Custom node input count " << inputs.size() << " invalid";
      return nullptr;
    }
    custom_node = cnode;
  }
  return custom_node;
}

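// A graph counts as a custom func graph when it consists of exactly one CustomAscend node.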
bool CustomAscendUtils::IsCustomFuncGraph(const FuncGraphPtr &func_graph) {
  return GetCustomNode(func_graph) != nullptr;
}

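// Inverse of CreateCustomFuncGraph: extract the OM model tensor, graph name, attributes, RefData
// tensors and dynamic KV-cache info from a graph that contains a single CustomAscend node.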
bool CustomAscendUtils::ParseCustomFuncGraph(const FuncGraphPtr &func_graph, tensor::TensorPtr *model_cache,
                                             std::string *graph_name, std::map<std::string, ValuePtr> *attr_map,
                                             std::vector<std::pair<std::string, tensor::TensorPtr>> *ref_datas,
                                             DynKVCacheSaveInfo *dyn_kv_info) {
  MS_ERROR_IF_NULL_W_RET_VAL(func_graph, false);
  MS_ERROR_IF_NULL_W_RET_VAL(model_cache, false);
  MS_ERROR_IF_NULL_W_RET_VAL(graph_name, false);
  MS_ERROR_IF_NULL_W_RET_VAL(attr_map, false);
  MS_ERROR_IF_NULL_W_RET_VAL(ref_datas, false);
  auto custom_node = GetCustomNode(func_graph);
  if (custom_node == nullptr) {
    MS_LOG(ERROR) << "Cannot find Custom node, or another real node was found in the graph";
    return false;
  }
  auto inputs = custom_node->inputs();
  if (inputs.size() < 1) {
    MS_LOG(ERROR) << "Custom node input count " << inputs.size() << " invalid";
    return false;
  }
  auto input_last = *inputs.rbegin();
  if (!input_last) {
    MS_LOG(ERROR) << "Custom node last input is nullptr";
    return false;
  }
  auto tensor = FuncGraphUtils::GetParameterConstValue(input_last);
  if (tensor == nullptr) {
    MS_LOG(ERROR) << "Failed to cast parameter value to Tensor";
    return false;
  }
  if (tensor->data_c() == nullptr || tensor->Size() == 0) {
    MS_LOG(ERROR) << "Custom node tensor data is empty";
    return false;
  }
  auto prim = GetValueNode<PrimitivePtr>(custom_node->input(0));
  if (!prim) {
    MS_LOG(ERROR) << "Primitive of cnode " << custom_node->fullname_with_scope() << " cannot be nullptr";
    return false;
  }
  for (auto &attr : prim->attrs()) {
    (*attr_map)[attr.first] = attr.second;
  }
  auto attr_ref_datas = prim->GetAttr(lite::kNameAttrRefDatas);
  if (attr_ref_datas) {
    auto ref_datas_names = GetValue<std::vector<std::string>>(attr_ref_datas);
    std::vector<std::pair<std::string, tensor::TensorPtr>> zero_val_ref_infos;
    if (!GetZeroValueRefDatas(prim, &zero_val_ref_infos)) {
      MS_LOG(ERROR) << "Failed to get zero value ref data";
      return false;
    }
    auto parameters = func_graph->parameters();
    std::vector<AnfNodePtr> new_parameters = func_graph->get_inputs();
    for (auto &ref_name : ref_datas_names) {
      auto it = std::find_if(zero_val_ref_infos.begin(), zero_val_ref_infos.end(),
                             [&ref_name](const auto &info) { return info.first == ref_name; });
      if (it != zero_val_ref_infos.end()) {
        ref_datas->push_back(std::make_pair(ref_name, it->second));
        continue;
      }
      auto p_it = std::find_if(parameters.begin(), parameters.end(),
                               [&ref_name](auto &item) { return item->fullname_with_scope() == ref_name; });
      if (p_it == parameters.end() || *p_it == nullptr) {
        MS_LOG(ERROR) << "Cannot find RefData parameter " << ref_name;
        return false;
      }
      auto ref_tensor = FuncGraphUtils::GetParameterConstValue(*p_it);
      if (ref_tensor == nullptr) {
        MS_LOG(ERROR) << "Failed to find tensor value for parameter " << ref_name;
        return false;
      }
      ref_datas->push_back(std::make_pair(ref_name, ref_tensor));
    }
  }
  LoadDynKVCacheInfo(*attr_map, dyn_kv_info);
  *model_cache = tensor;
  *graph_name = input_last->fullname_with_scope();
  return true;
}
}  // namespace mindspore
479