• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "mindspore/lite/tools/common/custom_ascend_utils.h"
#include <algorithm>
#include <cstring>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/lite/src/common/log_util.h"
#include "mindspore/core/utils/ms_utils_secure.h"
#include "mindspore/lite/tools/common/func_graph_utils.h"
#include "mindspore/core/ops/tuple_get_item.h"
#include "mindspore/lite/src/common/common.h"
#include "mindspore/lite/tools/optimizer/common/gllo_utils.h"
22 #include "mindspore/lite/tools/optimizer/common/gllo_utils.h"
23 
24 namespace mindspore {
25 namespace {
// Custom op type string passed to ops::Custom::set_type() for the generated node.
constexpr const char *kCustomPrimTypeACL = "ACL";
// Full name (with scope) assigned to the generated custom node.
constexpr const char *kCustomNodeName = "custom_0";
// Attribute keys written onto the custom primitive by SetCustomAttrs.
constexpr const char *kFuncType = "func_type";
constexpr const char *kUniqueName = "uniq_name";
30 
SaveDynKVCacheInfo(const DynKVCacheSaveInfo & dyn_kv_info,std::map<std::string,ValuePtr> * attr_map)31 void SaveDynKVCacheInfo(const DynKVCacheSaveInfo &dyn_kv_info, std::map<std::string, ValuePtr> *attr_map) {
32   if (!dyn_kv_info.batch_size_dyn && !dyn_kv_info.seq_length_dyn) {
33     return;
34   }
35   std::vector<std::string> dynamic_kv_cache;
36   dynamic_kv_cache.push_back("batch_size_dyn");
37   dynamic_kv_cache.push_back(std::to_string(dyn_kv_info.batch_size_dyn));
38   dynamic_kv_cache.push_back("seq_length_dyn");
39   dynamic_kv_cache.push_back(std::to_string(dyn_kv_info.seq_length_dyn));
40   dynamic_kv_cache.push_back("kv_cache_layout");
41   dynamic_kv_cache.push_back(dyn_kv_info.kv_cache_layout);
42   (*attr_map)["dynamic_kv_cache"] = MakeValue(dynamic_kv_cache);
43 }
44 
LoadDynKVCacheInfo(const std::map<std::string,ValuePtr> & attr_map,DynKVCacheSaveInfo * dyn_kv_info)45 void LoadDynKVCacheInfo(const std::map<std::string, ValuePtr> &attr_map, DynKVCacheSaveInfo *dyn_kv_info) {
46   if (dyn_kv_info == nullptr) {
47     return;
48   }
49   auto it = attr_map.find("dynamic_kv_cache");
50   if (it == attr_map.end()) {
51     return;
52   }
53   auto option_pairs = GetValue<std::vector<std::string>>(it->second);
54 
55   constexpr size_t pair_size = 2;
56   if (option_pairs.size() % pair_size != 0) {
57     MS_LOG(WARNING) << "Attr dynamic_kv_cache value sequence size " << option_pairs.size()
58                     << " is invalid, option paris: " << option_pairs;
59   }
60   for (size_t i = 0; i + 1 < option_pairs.size(); i += pair_size) {
61     auto &key = option_pairs[i];
62     auto &val = option_pairs[i + 1];
63     MS_LOG(INFO) << "Set dynamic_kv_cache option " << key << ": " << val;
64     if (key == "batch_size_dyn") {
65       dyn_kv_info->batch_size_dyn = std::stoi(val);
66     } else if (key == "seq_length_dyn") {
67       dyn_kv_info->seq_length_dyn = std::stoi(val);
68     } else if (key == "kv_cache_layout") {
69       dyn_kv_info->kv_cache_layout = val;
70     }
71   }
72 }
73 }  // namespace
74 
CreateOmParameter(const FuncGraphPtr & func_graph,const Buffer & om_data,const std::string & graph_name)75 ParameterPtr CustomAscendUtils::CreateOmParameter(const FuncGraphPtr &func_graph, const Buffer &om_data,
76                                                   const std::string &graph_name) {
77   MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
78   ParameterPtr om_parameter = func_graph->add_parameter();
79   MS_CHECK_TRUE_MSG(om_parameter != nullptr, nullptr, "om_parameter is nullptr.");
80   om_parameter->set_name(graph_name);
81   om_parameter->debug_info()->set_name(graph_name);
82 
83   auto type_ptr = TypeIdToType(kNumberTypeUInt8);
84   MS_CHECK_TRUE_MSG(type_ptr != nullptr, nullptr, "type_ptr is nullptr.");
85   ShapeVector shape_vector = {static_cast<int64_t>(om_data.DataSize())};
86   auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
87   MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, nullptr, "abstract_tensor is nullptr.");
88   om_parameter->set_abstract(abstract_tensor);
89 
90   auto param_value =
91     std::make_shared<tensor::Tensor>(kNumberTypeUInt8, ShapeVector({static_cast<int64_t>(om_data.DataSize())}));
92   MS_CHECK_TRUE_MSG(param_value != nullptr, nullptr, "param_value is nullptr.");
93   auto tensor_data = param_value->data_c();
94   MS_CHECK_TRUE_MSG(tensor_data != nullptr, nullptr, "New Tensor failed.");
95   if (param_value->Size() < om_data.DataSize()) {
96     MS_LOG(ERROR) << "Dst buff size  " << param_value->Size() << " should be greater than src buff size "
97                   << om_data.DataSize();
98     return nullptr;
99   }
100   if (common::huge_memcpy(reinterpret_cast<uint8_t *>(tensor_data), param_value->Size(),
101                           reinterpret_cast<const uint8_t *>(om_data.Data()), om_data.DataSize()) != EOK) {
102     MS_LOG(ERROR) << "Memcpy om data failed.";
103     return nullptr;
104   }
105   om_parameter->set_default_param(param_value);
106   return om_parameter;
107 }
108 
SetCustomOutputs(const FuncGraphPtr & func_graph,const CNodePtr & custom_node)109 bool CustomAscendUtils::SetCustomOutputs(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
110   if (outputs_.size() == 1) {
111     auto abstract_tensor = FuncGraphUtils::GetAbstractFromNode(outputs_[0]);
112     if (abstract_tensor == nullptr) {
113       MS_LOG(ERROR) << "Abstract_tensor is nullptr.";
114       return false;
115     }
116     auto abstract_tensor_clone = abstract_tensor->Clone();
117     abstract_tensor_clone->set_name(abstract_tensor->name());
118     custom_node->set_abstract(abstract_tensor_clone);
119     return true;
120   } else {
121     AbstractBasePtrList abstract_list;
122     for (size_t j = 0; j < outputs_.size(); j++) {
123       auto abstract_tensor = FuncGraphUtils::GetAbstractFromNode(outputs_[j]);
124       if (abstract_tensor == nullptr) {
125         MS_LOG(ERROR) << "Abstract tensor is nullptr for output " << j;
126         return false;
127       }
128       auto abstract_tensor_clone = abstract_tensor->Clone();
129       abstract_tensor_clone->set_name(abstract_tensor->name());
130       (void)abstract_list.emplace_back(abstract_tensor_clone);
131     }
132     custom_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
133   }
134   return true;
135 }
136 
SetZeroValueRefDatas(const ops::PrimitiveCPtr & primc,const std::vector<std::pair<std::string,tensor::TensorPtr>> & ref_infos)137 void CustomAscendUtils::SetZeroValueRefDatas(const ops::PrimitiveCPtr &primc,
138                                              const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_infos) {
139   ValuePtrList value_ptr_list;
140   for (const auto &item : ref_infos) {
141     (void)value_ptr_list.emplace_back(MakeValue<std::string>(item.first));
142     (void)value_ptr_list.emplace_back(MakeValue(static_cast<uint64_t>(item.second->data_type())));
143     (void)value_ptr_list.emplace_back(MakeValue(item.second->shape_c()));
144   }
145   (void)primc->AddAttr(lite::kNameAttrZeroValRefDatas, MakeValue(value_ptr_list));
146 }
147 
GetZeroValueRefDatas(const ops::PrimitiveCPtr & primc,std::vector<std::pair<std::string,tensor::TensorPtr>> * ref_infos)148 bool CustomAscendUtils::GetZeroValueRefDatas(const ops::PrimitiveCPtr &primc,
149                                              std::vector<std::pair<std::string, tensor::TensorPtr>> *ref_infos) {
150   auto attr = primc->GetAttr(lite::kNameAttrZeroValRefDatas);
151   if (attr == nullptr) {
152     MS_LOG(INFO) << "Not found attr " << lite::kNameAttrZeroValRefDatas << " in custom node";
153     return true;
154   }
155   auto value_ptr_list = GetValue<ValuePtrList>(attr);
156   constexpr size_t every_item_size = 3;
157   if (value_ptr_list.size() % every_item_size != 0) {
158     MS_LOG(ERROR) << "Attr " << lite::kNameAttrZeroValRefDatas << " item count should be multiply of 3, but got "
159                   << value_ptr_list.size();
160     return false;
161   }
162   for (size_t i = 0; i < value_ptr_list.size(); i += every_item_size) {
163     auto param_name = GetValue<std::string>(value_ptr_list[i]);
164     auto data_type = static_cast<TypeId>(GetValue<uint64_t>(value_ptr_list[i + 1]));
165     auto param_shape = GetValue<ShapeVector>(value_ptr_list[i + 2]);
166     auto tensor = std::make_shared<tensor::Tensor>(data_type, param_shape);
167     ref_infos->push_back(std::make_pair(param_name, tensor));
168   }
169   return true;
170 }
171 
CreateCustomNode(const FuncGraphPtr & func_graph,const ParameterPtr & om_parameter,const std::map<std::string,ValuePtr> & attr_map,const std::vector<std::string> & ref_datas)172 CNodePtr CustomAscendUtils::CreateCustomNode(const FuncGraphPtr &func_graph, const ParameterPtr &om_parameter,
173                                              const std::map<std::string, ValuePtr> &attr_map,
174                                              const std::vector<std::string> &ref_datas) {
175   MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
176   auto prim = std::make_shared<mindspore::ops::Custom>();
177   MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "New custom op failed.");
178   prim->set_type(kCustomPrimTypeACL);
179   auto prim_c = prim->GetPrim();
180   auto graph_input = func_graph->get_inputs();
181   std::vector<std::pair<std::string, tensor::TensorPtr>> zeor_val_ref_infos;
182   for (auto &item : func_graph->parameters()) {
183     if (item && item->cast<ParameterPtr>() != nullptr) {
184       auto parameter = item->cast<ParameterPtr>();
185       auto param_name = parameter->name();
186       if (std::find(ref_datas.begin(), ref_datas.end(), param_name) != ref_datas.end()) {
187         auto value = parameter->default_param();
188         if (value == nullptr) {
189           continue;
190         }
191         auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
192         if (tensor == nullptr) {
193           continue;
194         }
195         if (IsParameterValueZero(tensor)) {
196           zeor_val_ref_infos.push_back(std::make_pair(param_name, tensor));
197         } else {
198           graph_input.push_back(item);
199         }
200       }
201     }
202   }
203   CNodePtr custom_node = func_graph->NewCNode(prim_c, graph_input);
204   MS_CHECK_TRUE_MSG(custom_node != nullptr, nullptr, "Custom cnode failed.");
205   custom_node->set_fullname_with_scope(kCustomNodeName);
206   custom_node->add_input(om_parameter);
207 
208   if (!SetCustomOutputs(func_graph, custom_node)) {
209     MS_LOG(ERROR) << "Set custom outputs failed.";
210     return nullptr;
211   }
212   SetCustomAttrs(prim, attr_map);
213   SetZeroValueRefDatas(prim_c, zeor_val_ref_infos);
214   (void)prim->AddAttr(lite::kNameAttrRefDatas, api::MakeValue(ref_datas));
215   return custom_node;
216 }
217 
SetCustomAttrs(const std::shared_ptr<ops::Custom> & prim,const std::map<std::string,ValuePtr> & attr_map)218 void CustomAscendUtils::SetCustomAttrs(const std::shared_ptr<ops::Custom> &prim,
219                                        const std::map<std::string, ValuePtr> &attr_map) {
220   std::string output_dim_str;
221   for (const auto &item : outputs_) {
222     auto shape = opt::GetAnfNodeOutputShape(item.first, item.second);
223     output_dim_str += std::to_string(shape.size()) + ",";
224     for (const auto &val : shape) {
225       output_dim_str += std::to_string(val) + ",";
226     }
227   }
228   std::vector<uint8_t> output_dim_char(output_dim_str.begin(), output_dim_str.end());
229   std::map<std::string, std::vector<uint8_t>> attrs = {{lite::kOutputShapes, output_dim_char}};
230   prim->set_attr(attrs);
231   prim->AddAttr(kFuncType, api::MakeValue<std::string>("acl_build"));
232   prim->AddAttr(kUniqueName, api::MakeValue<std::string>(lite::kNameCustomAscend));
233   auto prim_c = prim->GetPrim();
234   for (auto &attr : attr_map) {
235     prim_c->AddAttr(attr.first, attr.second);
236   }
237 }
238 
/**
 * Builds MakeTuple(TupleGetItem(custom_node, 0), ..., TupleGetItem(custom_node, n - 1)), where
 * n = outputs_.size(). This lets a multi-output custom node replace the graph's output.
 * Each getitem node is named "<custom_node full name>_getitem_<index>". No abstracts are set on the new nodes.
 * @return the MakeTuple cnode, or nullptr on failure.
 */
CNodePtr CustomAscendUtils::CreateMakeTupleGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
  AnfNodePtrList make_tuple_inputs;
  make_tuple_inputs.reserve(outputs_.size() + 1);
  for (size_t index = 0; index < outputs_.size(); ++index) {
    auto get_item_op = std::make_shared<ops::TupleGetItem>();
    if (get_item_op == nullptr) {
      MS_LOG(ERROR) << "New TupleGetItem failed for output " << index;
      return nullptr;
    }
    auto get_item_prim_node = NewValueNode(get_item_op->GetPrim());
    MS_CHECK_TRUE_MSG(get_item_prim_node != nullptr, nullptr, "item_prim is nullptr.");
    auto index_node = NewValueNode(MakeValue<int64_t>(index));
    MS_CHECK_TRUE_MSG(index_node != nullptr, nullptr, "item_value is nullptr.");
    AnfNodePtrList get_item_inputs{get_item_prim_node, custom_node, index_node};
    CNodePtr get_item_cnode = func_graph->NewCNode(get_item_inputs);
    if (get_item_cnode == nullptr) {
      MS_LOG(ERROR) << "New get item cnode failed for output " << index;
      return nullptr;
    }
    get_item_cnode->set_fullname_with_scope(custom_node->fullname_with_scope() + "_getitem_" +
                                            std::to_string(index));
    make_tuple_inputs.push_back(get_item_cnode);
  }

  auto make_tuple_prim_node = NewValueNode(prim::kPrimMakeTuple);
  MS_CHECK_TRUE_MSG(make_tuple_prim_node != nullptr, nullptr, "New make tuple val node failed.");
  // The MakeTuple primitive must be input 0, ahead of the getitem nodes.
  (void)make_tuple_inputs.insert(make_tuple_inputs.begin(), make_tuple_prim_node);
  auto make_tuple_cnode = func_graph->NewCNode(make_tuple_inputs);
  MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, nullptr, "New make tuple cnode failed.");
  return make_tuple_cnode;
}
269 
ModifyGraphByCustomNode(const FuncGraphPtr & func_graph,const CNodePtr & custom_node)270 bool CustomAscendUtils::ModifyGraphByCustomNode(const FuncGraphPtr &func_graph, const CNodePtr &custom_node) {
271   auto manager = Manage(func_graph, true);
272   if (manager == nullptr) {
273     MS_LOG(ERROR) << "Manager is nullptr";
274     return false;
275   }
276   AnfNodePtr return_input = func_graph->output();
277   MS_CHECK_TRUE_MSG(return_input != nullptr, lite::RET_ERROR, "return input is nullptr.");
278   if (outputs_.size() == 1) {
279     if (!manager->Replace(return_input, custom_node)) {
280       MS_LOG(ERROR) << "Replace node failed.";
281       return false;
282     }
283   } else {
284     auto make_tuple_node = CreateMakeTupleGraphOutput(func_graph, custom_node);
285     MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, lite::RET_ERROR, "Create make tuple cnode failed.");
286     if (!manager->Replace(return_input, make_tuple_node)) {
287       MS_LOG(ERROR) << "Replace node failed for outputs of graph.";
288       return false;
289     }
290   }
291   std::vector<AnfNodePtr> new_parameters;
292   auto node_users = manager->node_users();
293   for (auto &item : func_graph->parameters()) {
294     auto parameter = item->cast<ParameterPtr>();
295     if (!parameter) {
296       continue;
297     }
298     if (!parameter->has_default()) {
299       new_parameters.push_back(parameter);
300     } else {
301       auto users = node_users.find(item);
302       if (!users->second.empty()) {
303         new_parameters.push_back(item);
304       }
305     }
306   }
307   manager->SetParameters(func_graph, new_parameters);
308   MS_LOG(DEBUG) << "Modify graph by custom node success.";
309   return true;
310 }
311 
IsParameterValueZero(const tensor::TensorPtr & tensor)312 bool CustomAscendUtils::IsParameterValueZero(const tensor::TensorPtr &tensor) {
313   if (tensor == nullptr) {
314     return false;
315   }
316   auto size = tensor->Size();
317   auto count = size / sizeof(uint64_t);
318   auto data_u8 = reinterpret_cast<uint8_t *>(tensor->data_c());
319   if (data_u8 == nullptr) {
320     return false;
321   }
322   auto data_u64 = reinterpret_cast<uint64_t *>(tensor->data_c());
323   for (size_t i = 0; i < count; i++) {
324     if (data_u64[i] != 0) {
325       return false;
326     }
327   }
328   for (size_t i = count * sizeof(uint64_t); i < size; i++) {
329     if (data_u8[i] != 0) {
330       return false;
331     }
332   }
333   return true;
334 }
335 
CreateCustomFuncGraph(const FuncGraphPtr & func_graph,const Buffer & model_cache,const std::string & graph_name,const std::map<std::string,ValuePtr> & attr_map,const std::vector<std::string> & ref_datas,const DynKVCacheSaveInfo & dyn_kv_info)336 bool CustomAscendUtils::CreateCustomFuncGraph(const FuncGraphPtr &func_graph, const Buffer &model_cache,
337                                               const std::string &graph_name,
338                                               const std::map<std::string, ValuePtr> &attr_map,
339                                               const std::vector<std::string> &ref_datas,
340                                               const DynKVCacheSaveInfo &dyn_kv_info) {
341   CustomAscendUtils utils;
342   utils.outputs_ = opt::GetNodeInputs(func_graph->get_return());
343   auto om_parameter = CreateOmParameter(func_graph, model_cache, graph_name);
344   if (om_parameter == nullptr) {
345     MS_LOG(ERROR) << "Create custom parameter failed";
346     return false;
347   }
348   std::map<std::string, ValuePtr> attr_map_new = attr_map;
349   SaveDynKVCacheInfo(dyn_kv_info, &attr_map_new);
350   auto cnode = utils.CreateCustomNode(func_graph, om_parameter, attr_map_new, ref_datas);
351   if (cnode == nullptr) {
352     MS_LOG(ERROR) << "Create custom cnode failed";
353     return false;
354   }
355   if (!utils.ModifyGraphByCustomNode(func_graph, cnode)) {
356     return false;
357   }
358   return true;
359 }
360 
GetCustomNode(const FuncGraphPtr & func_graph)361 CNodePtr CustomAscendUtils::GetCustomNode(const FuncGraphPtr &func_graph) {
362   if (func_graph == nullptr) {
363     return nullptr;
364   }
365   auto nodes = func_graph->TopoSort(func_graph->get_return());
366   if (nodes.empty()) {
367     MS_LOG(WARNING) << "There are no nodes in the graph";
368     return nullptr;
369   }
370   CNodePtr custom_node = nullptr;
371   size_t cnode_count = 0;
372   for (auto &node : nodes) {
373     auto cnode = node->cast<CNodePtr>();
374     if (!cnode || !AnfUtils::IsRealKernel(cnode)) {
375       continue;
376     }
377     std::string kernel_name = AnfUtils::GetCNodeName(cnode);
378     if (kernel_name != lite::kNameCustomAscend) {
379       return nullptr;
380     }
381     cnode_count += 1;
382     if (cnode_count > 1) {
383       MS_LOG(ERROR) << "Only support one " << lite::kNameCustomAscend << " node, but got " << kernel_name << ", node "
384                     << cnode->fullname_with_scope();
385       return nullptr;
386     }
387     auto inputs = cnode->inputs();
388     if (inputs.size() < 1) {
389       MS_LOG(ERROR) << "Custom node input count " << inputs.size() << " invalid";
390       return nullptr;
391     }
392     custom_node = cnode;
393   }
394   return custom_node;
395 }
396 
IsCustomFuncGraph(const FuncGraphPtr & func_graph)397 bool CustomAscendUtils::IsCustomFuncGraph(const FuncGraphPtr &func_graph) {
398   return GetCustomNode(func_graph) != nullptr;
399 }
400 
ParseCustomFuncGraph(const FuncGraphPtr & func_graph,tensor::TensorPtr * model_cache,std::string * graph_name,std::map<std::string,ValuePtr> * attr_map,std::vector<std::pair<std::string,tensor::TensorPtr>> * ref_datas,DynKVCacheSaveInfo * dyn_kv_info)401 bool CustomAscendUtils::ParseCustomFuncGraph(const FuncGraphPtr &func_graph, tensor::TensorPtr *model_cache,
402                                              std::string *graph_name, std::map<std::string, ValuePtr> *attr_map,
403                                              std::vector<std::pair<std::string, tensor::TensorPtr>> *ref_datas,
404                                              DynKVCacheSaveInfo *dyn_kv_info) {
405   MS_ERROR_IF_NULL_W_RET_VAL(func_graph, false);
406   MS_ERROR_IF_NULL_W_RET_VAL(model_cache, false);
407   MS_ERROR_IF_NULL_W_RET_VAL(graph_name, false);
408   MS_ERROR_IF_NULL_W_RET_VAL(attr_map, false);
409   MS_ERROR_IF_NULL_W_RET_VAL(ref_datas, false);
410   auto custom_node = GetCustomNode(func_graph);
411   if (custom_node == nullptr) {
412     MS_LOG(ERROR) << "Cannot find Custom node, or other real node find in the graph";
413     return false;
414   }
415   auto inputs = custom_node->inputs();
416   if (inputs.size() < 1) {
417     MS_LOG(ERROR) << "Custom node input count " << inputs.size() << " invalid";
418     return false;
419   }
420   auto input_last = *inputs.rbegin();
421   if (!input_last) {
422     MS_LOG(ERROR) << "Custom node last input is nullptr";
423     return false;
424   }
425   auto tensor = FuncGraphUtils::GetParameterConstValue(input_last);
426   if (tensor == nullptr) {
427     MS_LOG(ERROR) << "Failed to cast parameter value to Tensor";
428     return false;
429   }
430   if (tensor->data_c() == nullptr || tensor->Size() == 0) {
431     MS_LOG(ERROR) << "Custom node tensor data is empty";
432     return false;
433   }
434   auto prim = GetValueNode<PrimitivePtr>(custom_node->input(0));
435   if (!prim) {
436     MS_LOG(ERROR) << "Primitive of cnode " << custom_node->fullname_with_scope() << " cannot be nullptr";
437     return false;
438   }
439   for (auto &attr : prim->attrs()) {
440     (*attr_map)[attr.first] = attr.second;
441   }
442   auto attr_ref_datas = prim->GetAttr(lite::kNameAttrRefDatas);
443   if (attr_ref_datas) {
444     auto ref_datas_names = GetValue<std::vector<std::string>>(attr_ref_datas);
445     std::vector<std::pair<std::string, tensor::TensorPtr>> zero_val_ref_infos;
446     if (!GetZeroValueRefDatas(prim, &zero_val_ref_infos)) {
447       MS_LOG(ERROR) << "Failed to get zero value ref data";
448       return false;
449     }
450     auto parameters = func_graph->parameters();
451     std::vector<AnfNodePtr> new_parameters = func_graph->get_inputs();
452     for (auto &ref_name : ref_datas_names) {
453       auto it = std::find_if(zero_val_ref_infos.begin(), zero_val_ref_infos.end(),
454                              [&ref_name](const auto &info) { return info.first == ref_name; });
455       if (it != zero_val_ref_infos.end()) {
456         ref_datas->push_back(std::make_pair(ref_name, it->second));
457         continue;
458       }
459       auto p_it = std::find_if(parameters.begin(), parameters.end(),
460                                [&ref_name](auto &item) { return item->fullname_with_scope() == ref_name; });
461       if (p_it == parameters.end() || *p_it == nullptr) {
462         MS_LOG(ERROR) << "Cannot find RefData parameter " << ref_name;
463         return false;
464       }
465       auto ref_tensor = FuncGraphUtils::GetParameterConstValue(*p_it);
466       if (ref_tensor == nullptr) {
467         MS_LOG(ERROR) << "Failed to find tensor value for parameter " << ref_name;
468         return false;
469       }
470       ref_datas->push_back(std::make_pair(ref_name, ref_tensor));
471     }
472   }
473   LoadDynKVCacheInfo(*attr_map, dyn_kv_info);
474   *model_cache = tensor;
475   *graph_name = input_last->fullname_with_scope();
476   return true;
477 }
478 }  // namespace mindspore
479