1 /**
2  * Copyright 2022-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "extendrt/delegate/ascend_ge/ge_graph_executor.h"
18 #include <tuple>
19 #include <algorithm>
20 #include <utility>
21 #include "mindspore/core/ops/framework_ops.h"
22 #include "extendrt/delegate/factory.h"
23 #include "include/common/utils/scoped_long_running.h"
24 #include "include/api/context.h"
25 #include "include/api/status.h"
26 #include "include/transform/graph_ir/utils.h"
27 #include "include/backend/device_type.h"
28 #include "runtime/device/ms_device_shape_transfer.h"
29 #include "src/common/common.h"
30 #include "src/common/file_utils.h"
31 #include "cxx_api/acl_utils.h"
32 #include "mindspore/core/utils/ms_utils_secure.h"
33 #include "tools/optimizer/common/gllo_utils.h"
34 #include "tools/optimizer/graph/remove_load_pass.h"
35 #include "src/extendrt/utils/func_graph_utils.h"
36 #include "transform/graph_ir/transform_util.h"
37 #include "flow_graph/data_flow.h"
38 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
39 #include "tools/graph_kernel/converter/graph_kernel_optimization.h"
40 #endif
41 #include "src/extendrt/utils/tensor_utils.h"
42 #include "external/ge_common/ge_api_error_codes.h"
43 #include "src/extendrt/delegate/ascend_ge/aoe_api_tune_process.h"
44 #include "extendrt/delegate/ascend_ge/ge_utils.h"
45 #include "extendrt/delegate/ascend_ge/ge_dynamic_utils.h"
46 #include "mindspore/core/ops/lite_ops.h"
47 #include "mindspore/core/ops/nn_optimizer_ops.h"
48 #include "mindspore/lite/tools/common/string_util.h"
49 #include "mindspore/lite/src/extendrt/cxx_api/file_utils.h"
50 #include "mindspore/core/ops/custom.h"
51 #include "mindspore/lite/src/common/common.h"
52 #include "mindspore/lite/tools/common/custom_ascend_utils.h"
53 #include "op_proto/inc/array_ops.h"
54 #include "op_proto/inc/elewise_calculation_ops.h"
55 #include "mindspore/lite/tools/optimizer/graph/attr_to_args_pass.h"
56 #include "mindspore/core/ops/nn_ops.h"
57 #include <nlohmann/json.hpp>
58 
59 namespace mindspore {
60 namespace {
61 constexpr auto kProviderGe = "ge";
62 constexpr auto kDump = "dump";
63 constexpr auto kDumpMode = "dump_mode";
64 constexpr auto kProfiling = "profiler";
65 constexpr auto kDataFlowGraphType = "data_flow";
66 constexpr auto kCustomInputSize = 2;
67 constexpr auto kGraphKernelParam = "graph_kernel_param";
68 constexpr auto kUnkonwnSessionId = -1;
69 constexpr auto kRefModeNone = "none";
70 constexpr auto kRefModeVariable = "variable";
71 constexpr auto kRefModeAll = "all";
72 constexpr float kNumMicrosecondToMillisecond = 1000.0;
73 constexpr size_t kAlignRefData = 32;
74 
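// Round (size + kAlignRefData) up to a multiple of kMemAlignSize when reserving device memory for RefData tensors.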
75 size_t ALIGN_UP_REF_DATA(size_t size) {
76   return ((size + kMemAlignSize + kAlignRefData - 1) / kMemAlignSize) * kMemAlignSize;
77 }
78 
79 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
80 std::shared_ptr<ConverterPara> ParseGraphKernelConfigs(const ConfigInfos &maps) {
81   if (maps.find(kGraphKernelParam) == maps.end()) {
82     return nullptr;
83   }
84   auto param = std::make_shared<ConverterPara>();
85   const auto &gk_map = maps.at(kGraphKernelParam);
86   std::stringstream oss;
87   for (const auto &item : gk_map) {
88     oss << "--" << item.first << "=" << item.second << " ";
89   }
90   param->device = GetSocVersion();
91   param->graphKernelParam.graph_kernel_flags = oss.str();
92   return param;
93 }
94 #endif
95 
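// Build a minimal Data -> Add placeholder graph; used as the "small ge graph" registered in place of the real graph when a compiled cache can be reused.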
96 transform::DfGraphPtr GenExampleGraph(const std::string &name) {
97   MS_LOG(INFO) << "Gen fake graph name is " << name;
98   auto graph = std::make_shared<transform::DfGraph>(name);
99   auto shape_data = std::vector<int64_t>({1, 1, 1, 1});
100   transform::GeTensorDesc desc_data(ge::Shape(shape_data), ge::FORMAT_ND, ge::DT_FLOAT16);
101   auto data = ge::op::Data("data");
102   data.set_attr_index(0);
103   data.update_input_desc_x(desc_data);
104   data.update_output_desc_y(desc_data);
105   auto add = ge::op::Add("add").set_input_x1(data).set_input_x2(data);
106   std::vector<transform::Operator> inputs{data};
107   std::vector<transform::Operator> outputs{add};
108   graph->SetInputs(inputs);
109   graph->SetOutputs(outputs);
110   return graph;
111 }
112 
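// Drop the var_desc_file_name entry from the model build cache idx JSON and rewrite the file read-only; returns false when the idx file is missing or invalid.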
113 bool UpdateOmCacheIdxFile(const std::string &idx_file_name) {
114   std::ifstream ifs(idx_file_name);
115   if (!ifs.good() || !ifs.is_open()) {
116     MS_LOG(INFO) << "model cache idx json does not exist, idx file: " << idx_file_name << ", skip create small ge graph";
117     return false;
118   }
119   nlohmann::json dump_cfg_json;
120   try {
121     dump_cfg_json = nlohmann::json::parse(ifs);
122     const std::string cache_file_list = "cache_file_list";
123     const std::string var_desc_file_name = "var_desc_file_name";
124     auto &cache_file_config = dump_cfg_json[cache_file_list];
125     if (cache_file_config == nullptr || cache_file_config[0] == nullptr) {
126       MS_LOG(WARNING) << "model cache idx json content invalid, idx file: " << idx_file_name
127                       << ", skip create small ge graph";
128       return false;
129     }
130     auto &config = cache_file_config[0];
131     if (config[var_desc_file_name] != nullptr) {
132       config.erase(var_desc_file_name);
133       auto new_json_str = dump_cfg_json.dump(4);
134       ifs.close();
135       std::ofstream ofs(idx_file_name, std::ios::out);
136       if (!ofs.is_open()) {
137         MS_LOG(WARNING) << "Failed to open model cache idx file for write, idx file: " << idx_file_name
138                         << ", skip create small ge graph";
139         return false;
140       }
141       ofs << new_json_str;
142       ofs.close();
143 #ifndef _MSC_VER
144       chmod(idx_file_name.c_str(), S_IRUSR);
145 #endif
146       MS_LOG(INFO) << "Erase option " << var_desc_file_name;
147     }
148     return true;
149   } catch (const std::exception &error) {
150     MS_LOG(WARNING) << "parse model cache idx json failed, idx file: " << idx_file_name
151                     << ", skip create small ge graph";
152     return false;
153   }
154 }
155 }  // namespace
156 
157 std::atomic_uint32_t GeGraphExecutor::global_graph_idx_ = 0;
158 uint32_t GeGraphExecutor::GetNextGraphIdx() { return global_graph_idx_++; }
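// Extract the ge::dflow::FlowGraph stored in the Custom node's constant input and wrap it as a DfGraph.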
159 transform::DfGraphPtr GetDataFlowGraph(const FuncGraphPtr &anf_graph,
160                                        const std::map<std::string, std::string> &ge_options) {
161   MS_EXCEPTION_IF_NULL(anf_graph);
162   auto return_node = anf_graph->get_return();
163   MS_EXCEPTION_IF_NULL(return_node);
164   auto nodes = anf_graph->TopoSort(return_node);
165   auto itr = std::find_if(nodes.begin(), nodes.end(), [&](const AnfNodePtr &node) {
166     return node != nullptr && node->isa<CNode>() && opt::CheckPrimitiveType(node, prim::kPrimCustom);
167   });
168   if (itr == nodes.end()) {
169     MS_LOG(ERROR) << "The dataflow graph is invalid.";
170     return nullptr;
171   }
172   auto custom_cnode = (*itr)->cast<CNodePtr>();
173   MS_EXCEPTION_IF_NULL(custom_cnode);
174   if (custom_cnode->size() != kCustomInputSize) {
175     MS_LOG(ERROR) << "The input size of the dataflow custom node is not 2.";
176     return nullptr;
177   }
178   auto tensor = FuncGraphUtils::GetConstNodeValue(custom_cnode->input(1));
179   MS_EXCEPTION_IF_NULL(tensor);
180   auto data = tensor->data_c();
181   MS_EXCEPTION_IF_NULL(data);
182   auto flow_graph = reinterpret_cast<ge::dflow::FlowGraph *>(data);
183   MS_EXCEPTION_IF_NULL(flow_graph);
184   auto df_graph = std::make_shared<transform::DfGraph>(flow_graph->ToGeGraph());
185   return df_graph;
186 }
187 
188 GeGraphExecutor::~GeGraphExecutor() {
189   if (ge_session_) {
190     for (auto graph_id : init_graph_id_list_) {
191       ge_session_->RemoveGraph(graph_id);
192     }
193     for (auto graph_id : compute_graph_id_list_) {
194       ge_session_->RemoveGraph(graph_id);
195       auto session_context = GeSessionManager::GetGeSessionContext(session_id_);
196       if (session_context != nullptr) {
197         (void)session_context->feature_graph_ids.erase(graph_id);
198       }
199     }
200     ge_session_ = nullptr;
201     GeSessionManager::TryReleaseGeSessionContext(session_id_);
202     enable_update_weight_ = false;
203     update_weight_ptr_ = nullptr;
204   }
205 }
206 
207 bool GeGraphExecutor::SetGeTensorShape(GeTensor *ge_tensor, ShapeVector shape) {
208   auto ge_desc = ge_tensor->GetTensorDesc();
209   ge::Shape new_ge_shape(shape);
210   ge_desc.Update(new_ge_shape);
211   ge_desc.SetOriginShape(new_ge_shape);
212   ge_tensor->SetTensorDesc(ge_desc);
213   MS_LOG(INFO) << "In SetGeTensorShape update ge shape to :" << shape;
214   return true;
215 }
216 
217 bool GeGraphExecutor::InitInputDeviceTensor(const FuncGraphPtr &anf_graph) {
218   MS_LOG(INFO) << "Call InitInputDeviceTensor start.";
219   auto inputs = anf_graph->get_inputs();
220   inputs_buffer_infos_.resize(inputs.size());
221   for (size_t i = 0; i < inputs.size(); i++) {
222     auto &input_info = inputs_buffer_infos_[i];
223     auto shape = FuncGraphUtils::GetTensorShape({inputs[i], 0});
224 
225     /* set max_batch size and max_seq_len for dyn shape */
226     std::vector<int64_t> new_shape;
227     for (size_t j = 0; j < shape.size(); j++) {
228       if (shape[j] == abstract::Shape::kShapeDimAny) {
229         new_shape.push_back(dyn_kv_cache_info_.max_seq_len_size);
230       } else {
231         new_shape.push_back(shape[j]);
232       }
233     }
234 
235     MS_LOG(INFO) << "Init input_" << i << " buffer for ge, change shape: " << shape << " -> " << new_shape;
236     auto dtype = static_cast<TypeId>(FuncGraphUtils::GetTensorDataType({inputs[i], 0}));
237     if (!InitInOutDeviceBuffer("Input " + std::to_string(i), new_shape, dtype, &input_info)) {
238       return false;
239     }
240   }
241   return true;
242 }
243 
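// Query GE's compiled graph summary; when the compiled graph is static, pre-allocate device buffers for every graph output.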
244 bool GeGraphExecutor::InitOutputDeviceTensor(const FuncGraphPtr &anf_graph, uint32_t graph_id) {
245   MS_LOG(INFO) << "Call GE GetCompiledGraphSummary start, graph id " << graph_id;
246   auto graph_summary = ge_session_->GetCompiledGraphSummary(graph_id);
247   if (graph_summary == nullptr) {
248     MS_LOG(ERROR) << "Failed to call GE GetCompiledGraphSummary, graph id " << graph_id
249                   << ", error: " << ge::GEGetErrorMsg();
250     return false;
251   }
252   MS_LOG(INFO) << "Call GE GetCompiledGraphSummary end, graph id " << graph_id;
253   dyn_kv_cache_info_.is_ge_graph_static_ = graph_summary->IsStatic();
254   MS_LOG(INFO) << "GE graph is static :" << dyn_kv_cache_info_.is_ge_graph_static_ << ", graph id: " << graph_id;
255   std::vector<AnfWithOutIndex> outputs;
256   if (!FuncGraphUtils::GetFuncGraphOutputs(anf_graph, &outputs)) {
257     MS_LOG(ERROR) << "Failed to get func graph outputs";
258     return false;
259   }
260   outputs_buffer_infos_.resize(outputs.size());
261   if (dyn_kv_cache_info_.is_ge_graph_static_) {
262     std::vector<::ge::Shape> ge_shapes;
263     auto ge_status = graph_summary->GetOutputShapes(ge_shapes);
264     if (ge_status != ge::GRAPH_SUCCESS) {
265       MS_LOG(ERROR) << "Failed to call GetOutputShapes, status: " << ge_status;
266       return false;
267     }
268     if (outputs.size() != ge_shapes.size()) {
269       MS_LOG(ERROR) << "Output count got from graph " << outputs.size() << " != output count " << ge_shapes.size()
270                     << " got from GE";
271       return false;
272     }
273     for (size_t i = 0; i < outputs.size(); i++) {
274       auto &output_info = outputs_buffer_infos_[i];
275       auto shape = ge_shapes[i].GetDims();
276       auto dtype = static_cast<TypeId>(FuncGraphUtils::GetTensorDataType(outputs[i]));
277       if (!InitInOutDeviceBuffer("Output " + std::to_string(i), shape, dtype, &output_info)) {
278         return false;
279       }
280     }
281   }
282   return true;
283 }
284 
285 void GeGraphExecutor::SetRefShape(std::vector<int64_t> *ref_shape, bool dyn, std::string tensor_name) {
286   if (!dyn_kv_cache_info_.dynamic_kv_cache) {
287     return;
288   }
289   size_t b_index = kDim0;
290   size_t s_index = kDim2;
291   if (dyn_kv_cache_info_.kv_cache_layout == lite::kKVCacheLayoutBSH) {
292     s_index = kDim1;
293   }
294   if (dyn) {
295     if (dyn_kv_cache_info_.batch_size_dyn) {
296       (*ref_shape)[b_index] = abstract::Shape::kShapeDimAny;
297       MS_LOG(INFO) << "for " << tensor_name << " update batch size to dyn(-1) for ge_option.";
298     }
299     if (dyn_kv_cache_info_.seq_length_dyn) {
300       (*ref_shape)[s_index] = abstract::Shape::kShapeDimAny;
301       MS_LOG(INFO) << "for " << tensor_name << " update seq length size to dyn(-1) for ge_option.";
302     }
303   } else {
304     if (dyn_kv_cache_info_.batch_size_dyn) {
305       (*ref_shape)[b_index] = dyn_kv_cache_info_.real_batch_size;
306       MS_LOG(INFO) << "for " << tensor_name << " update batch size to " << dyn_kv_cache_info_.real_batch_size
307                    << " for ge_option.";
308     }
309     if (dyn_kv_cache_info_.seq_length_dyn) {
310       (*ref_shape)[s_index] = dyn_kv_cache_info_.real_seq_len_size;
311       MS_LOG(INFO) << "for " << tensor_name << " update seq length size to " << dyn_kv_cache_info_.real_seq_len_size
312                    << " for ge_option.";
313     }
314   }
315 }
316 
317 void GeGraphExecutor::UpdateOutputShapeInfo(std::vector<::ge::Tensor> *ge_outputs) {
318   MS_LOG(INFO) << "Update output dtype and shape.";
319   for (size_t i = 0; i < outputs_buffer_infos_.size(); i++) {
320     auto &output_info = outputs_buffer_infos_[i];
321     auto &ge_output = ge_outputs->at(i);
322     auto ge_tensor_desc = ge_output.GetTensorDesc();
323     output_info.shape = transform::TransformUtil::ConvertGeShape(ge_tensor_desc.GetShape());
324     output_info.dtype = transform::TransformUtil::ConvertGeDataType(ge_tensor_desc.GetDataType());
325     output_info.max_size = SizeOf(output_info.shape) * GetDataTypeSize(output_info.dtype);
326     auto out_device = ge_output.GetData();
327     if (dyn_kv_cache_info_.is_ge_graph_static_ && out_device != output_info.device_addr) {
328       MS_LOG(WARNING) << "GE output device address is not equal to the pre-allocated device memory when graph is static";
329     }
330     output_info.device_addr = out_device;
331     MS_LOG(INFO) << "Update output_" << i << " dtype: " << output_info.dtype << ", shape: " << output_info.shape;
332   }
333   return;
334 }
335 
336 bool GeGraphExecutor::SetDynamicKVCache(const FuncGraphPtr &func_graph) {
337   auto graph_inputs = func_graph->get_inputs();
338   auto has_dynamic_input = std::any_of(graph_inputs.begin(), graph_inputs.end(), [](const AnfNodePtr &input) {
339     auto shape = FuncGraphUtils::GetTensorShape({input, 0});
340     return std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; });
341   });
342   if (!has_dynamic_input) {
343     MS_LOG(INFO) << "No dynamic input detected in graph";
344     return true;
345   }
346   auto nodes = func_graph->TopoSort(func_graph->get_return());
347   if (nodes.empty()) {
348     MS_LOG(WARNING) << "There are no nodes in the graph";
349     return true;
350   }
351   constexpr size_t kv_index = 2;  // primitive, kv cache, kv
352   for (auto &node : nodes) {
353     auto cnode = node->cast<CNodePtr>();
354     if (!cnode || !IsPrimitiveCNode(cnode, prim::kPrimPromptKVCache)) {
355       continue;
356     }
357     auto inputs = cnode->inputs();
358     if (inputs.size() <= kv_index) {
359       MS_LOG(WARNING) << "PrimPromptKVCache " << cnode->fullname_with_scope() << " input size " << inputs.size() - 1
360                       << " <= kv index " << kv_index - 1;
361       continue;
362     }
363     auto kv_input = inputs[kv_index];
364     if (kv_input == nullptr) {
365       MS_LOG(WARNING) << "PrimPromptKVCache " << cnode->fullname_with_scope() << " kv input is nullptr";
366       continue;
367     }
368     if (!IsPrimitiveCNode(kv_input, prim::kPrimPadV3)) {
369       dyn_kv_cache_info_.dynamic_kv_cache = true;
370       dyn_kv_cache_info_.seq_length_dyn = true;
371       auto kv_shape = FuncGraphUtils::GetTensorShape({kv_input, 0});
372       if (kv_shape.size() == kShape4dDims) {
373         dyn_kv_cache_info_.kv_cache_layout = lite::kKVCacheLayoutBNSD;
374       } else if (kv_shape.size() == kShape3dDims) {
375         dyn_kv_cache_info_.kv_cache_layout = lite::kKVCacheLayoutBSH;
376       } else {
377         MS_LOG(ERROR) << "Expect RefData shape to be BNSD or BSH when dynamic kv cache is enabled, but got " << kv_shape;
378         return false;
379       }
380     }
381     break;
382   }
383   MS_LOG(INFO) << "set dyn kv info dynamic_kv_cache : " << dyn_kv_cache_info_.dynamic_kv_cache;
384   MS_LOG(INFO) << "set dyn kv info seq_length_dyn : " << dyn_kv_cache_info_.seq_length_dyn;
385   return true;
386 }
387 
388 bool GeGraphExecutor::CheckRefDataInfo() {
389   if (!dyn_kv_cache_info_.dynamic_kv_cache) {
390     return true;
391   }
392   auto &ref_shape = ref_data_infos_.front().shape;
393   for (size_t i = 0; i < ref_data_infos_.size(); i++) {
394     auto &ref_data_info = ref_data_infos_[i];
395     auto &para_name = ref_data_info.name;
396     if (dyn_kv_cache_info_.kv_cache_layout == lite::kKVCacheLayoutBSH) {
397       if (ref_data_info.shape.size() != kShape3dDims) {
398         MS_LOG(ERROR) << "KVCache shape size is not " << kShape3dDims << ", while KVCache layout is "
399                       << dyn_kv_cache_info_.kv_cache_layout << ", KVCache param " << para_name << ", shape "
400                       << ref_data_info.shape;
401         return false;
402       }
403     } else if (dyn_kv_cache_info_.kv_cache_layout == lite::kKVCacheLayoutBNSD) {
404       if (ref_data_info.shape.size() != kShape4dDims) {
405         MS_LOG(ERROR) << "KVCache shape size is not " << kShape4dDims << ", while KVCache layout is "
406                       << dyn_kv_cache_info_.kv_cache_layout << ", KVCache param " << para_name << ", shape "
407                       << ref_data_info.shape;
408         return false;
409       }
410     } else {
411       MS_LOG(ERROR) << "Unsupported KVCache layout " << dyn_kv_cache_info_.kv_cache_layout;
412       return false;
413     }
414     if (ref_shape != ref_data_info.shape) {
415       MS_LOG(ERROR) << "KVCache shape " << ref_data_info.shape << " of " << para_name << " != KVCache shape "
416                     << ref_shape << " of " << ref_data_infos_.front().name;
417       return false;
418     }
419   }
420   return true;
421 }
422 
423 bool GeGraphExecutor::InitMaxShapeParam() {
424   if (ref_data_infos_.empty()) {
425     MS_LOG(INFO) << "RefData count is empty";
426     return true;
427   }
428   if (!CheckRefDataInfo()) {
429     return false;
430   }
431   auto &ref_shape = ref_data_infos_.front().shape;
432   size_t b_index = kDim0;
433   size_t s_index = kDim2;
434   if (ref_shape.size() == kShape3dDims) {  // BSH
435     s_index = kDim1;
436   } else if (ref_shape.size() == kShape4dDims) {  // BNSD
437     s_index = kDim2;
438   } else {
439     MS_LOG(WARNING) << "RefData dim count is unexpected, shape " << ref_shape << ", name "
440                     << ref_data_infos_.front().name;
441     return true;
442   }
443   std::string max_batch_size;
444   if (GetConfigOption("ascend_context", "max_batch_size", &max_batch_size)) {
445     MS_LOG(INFO) << "Get max batch size from config file, ascend_context, max_batch_size";
446     dyn_kv_cache_info_.max_batch_size = std::stoi(max_batch_size);
447   } else {
448     MS_LOG(INFO) << "Get max batch size from ref data shape : " << ref_shape;
449     dyn_kv_cache_info_.max_batch_size = ref_shape[b_index];
450   }
451 
452   std::string max_seq_length;
453   if (GetConfigOption("ascend_context", "max_seq_length", &max_seq_length)) {
454     MS_LOG(INFO) << "Get max seq length from config file, ascend_context, max_seq_length";
455     dyn_kv_cache_info_.max_seq_len_size = std::stoi(max_seq_length);
456   } else {
457     MS_LOG(INFO) << "Get max seq length from ref data shape : " << ref_shape;
458     dyn_kv_cache_info_.max_seq_len_size = ref_shape[s_index];
459   }
460 
461   MS_LOG(INFO) << "set dynamic max shape, max batch size : " << dyn_kv_cache_info_.max_batch_size
462                << ", max seq length: " << dyn_kv_cache_info_.max_seq_len_size;
463   return true;
464 }
465 
466 bool GeGraphExecutor::InitRealShapeParam(const std::vector<tensor::Tensor> &inputs) {
467   if (!dyn_kv_cache_info_.dynamic_kv_cache) {
468     return true;
469   }
470   auto input_0_shape = inputs[0].shape_c();
471   if (input_0_shape.size() != kShape2dDims) {
472     MS_LOG(ERROR) << "Expected input 0 shape to be [bs, seq_length], but got " << input_0_shape;
473     return false;
474   }
475   dyn_kv_cache_info_.real_batch_size = input_0_shape.at(Index0);
476   MS_LOG(INFO) << "Real batch size : " << dyn_kv_cache_info_.real_batch_size;
477   dyn_kv_cache_info_.real_seq_len_size = input_0_shape.at(Index1);
478   MS_LOG(INFO) << "Real seq length size : " << dyn_kv_cache_info_.real_seq_len_size;
479   return true;
480 }
481 
482 bool GeGraphExecutor::GetConfigOption(const std::string &section_name, const std::string &option_name,
483                                       std::string *option_val) {
484   if (option_val == nullptr) {
485     MS_LOG(ERROR) << "Input argument option_val is nullptr";
486     return false;
487   }
488   auto config_it = config_infos_.find(section_name);
489   if (config_it == config_infos_.end()) {
490     return false;
491   }
492   auto &options = config_it->second;
493   auto option_it = options.find(option_name);
494   if (option_it == options.end()) {
495     return false;
496   }
497   *option_val = option_it->second;
498   return true;
499 }
500 
501 uint32_t GeGraphExecutor::GetRankID() const {
502   auto ascend_info = GeUtils::GetAscendDeviceInfo(context_);
503   if (ascend_info == nullptr) {
504     MS_LOG(ERROR) << "Can not find ascend device context.";
505     return 0;
506   }
507   return ascend_info->GetRankID();
508 }
509 
510 uint32_t GeGraphExecutor::GetDeviceID() const {
511   auto ascend_info = GeUtils::GetAscendDeviceInfo(context_);
512   if (ascend_info == nullptr) {
513     MS_LOG(ERROR) << "Can not find ascend device context.";
514     return 0;
515   }
516   return ascend_info->GetDeviceID();
517 }
518 
519 bool GeGraphExecutor::Init() {
520   ge_global_context_ = GeDeviceContext::InitGlobalContext(context_, config_infos_);
521   if (ge_global_context_ == nullptr) {
522     MS_LOG(ERROR) << "Failed to Init global context";
523     return false;
524   }
525   if (!InitRefModeConfig()) {
526     return false;
527   }
528   std::string model_cache_mode;
529   (void)GetConfigOption(lite::kAscendContextSection, lite::kModelCacheMode, &model_cache_mode);
530   if (!model_cache_mode.empty()) {
531     cache_mode_ = model_cache_mode;
532     MS_LOG(INFO) << "Set model cache mode " << model_cache_mode;
533   }
534   std::string variable_weights_list;
535   (void)GetConfigOption(lite::kAscendContextSection, "variable_weights_list", &variable_weights_list);
536   if (!variable_weights_list.empty()) {
537     update_weight_ptr_ = std::make_shared<UpdateWeight>();
538     if (update_weight_ptr_ == nullptr) {
539       MS_LOG(ERROR) << "init update weight ptr failed.";
540       return false;
541     }
542     if (!update_weight_ptr_->ParseUpdateWeightConfig(variable_weights_list)) {
543       MS_LOG(ERROR) << "ParseUpdateWeightConfig failed.";
544       update_weight_ptr_ = nullptr;
545       return false;
546     }
547     enable_update_weight_ = true;
548   }
549   return true;
550 }
551 
552 bool GeGraphExecutor::InitRefModeConfig() {
553   std::string ref_mode;
554   (void)GetConfigOption(lite::kAscendContextSection, lite::kParameterAsRefData, &ref_mode);
555   if (!ref_mode.empty()) {
556     ref_mode = lite::StringTolower(ref_mode);
557     if (ref_mode != kRefModeNone && ref_mode != kRefModeVariable && ref_mode != kRefModeAll) {
558       MS_LOG(ERROR) << "Only " << kRefModeNone << ", " << kRefModeVariable << " or " << kRefModeAll
559                     << " is supported for " << lite::kParameterAsRefData << ", but got " << ref_mode;
560       return false;
561     }
562     if (ref_mode == kRefModeAll) {
563       ref_mode_flag_ = transform::RefModeFlag::kRefModeAll;
564     } else if (ref_mode == kRefModeVariable) {
565       ref_mode_flag_ = transform::RefModeFlag::kRefModeVariable;
566     } else {
567       ref_mode_flag_ = transform::RefModeFlag::kRefModeNone;
568     }
569     MS_LOG(INFO) << "Set parameter ref mode " << ref_mode;
570   } else {
571     ref_mode_flag_ = transform::RefModeFlag::kRefModeNone;
572   }
573   return true;
574 }
575 
576 void GeGraphExecutor::GetGeSessionOptions(std::map<std::string, std::string> *ge_options_ptr) {
577   MS_EXCEPTION_IF_NULL(ge_options_ptr);
578   auto &ge_options = *ge_options_ptr;
579   ge_options["ge.trainFlag"] = "0";
580   ge_options["ge.enablePrintOpPass"] = "0";
581   ge_options["ge.exec.device_id"] = std::to_string(GetDeviceID());
582   ge_options["ge.exec.staticMemoryPolicy"] = "2";
583   if (ref_mode_flag_ != transform::RefModeFlag::kRefModeNone) {
584     ge_options["ge.constLifecycle"] = "graph";
585   }
586   auto config_it = config_infos_.find(lite::kGeSessionOptionsSection);
587   if (config_it != config_infos_.end()) {
588     for (auto &item : config_it->second) {
589       ge_options[item.first] = item.second;
590       MS_LOG(INFO) << "Set ge session option " << item.first << " to " << item.second;
591     }
592   }
593   config_it = config_infos_.find(lite::kAscendContextSection);
594   if (config_it != config_infos_.end()) {
595     GetGeSessionOptionsFromAscendContext(config_it->second, ge_options_ptr);
596   }
597 }
598 
599 bool GeGraphExecutor::SetModelCacheDir(std::map<std::string, std::string> *session_options_ptr) {
600   auto &ge_options = *session_options_ptr;
601   auto build_cache_dir = "model_build_cache_" + std::to_string(GetRankID());
602   if (lite::CreateDir(build_cache_dir) != RET_OK) {
603     MS_LOG(ERROR) << "Failed to create build cache dir " << build_cache_dir;
604     return false;
605   }
606   ge_options[kGeGraphCompilerCacheDir] = build_cache_dir;
607   MS_LOG(INFO) << "Update session attr " << kGeGraphCompilerCacheDir << " to " << build_cache_dir;
608   return true;
609 }
610 
611 bool GeGraphExecutor::SetOfflineBuildModelCacheDir(std::map<std::string, std::string> *session_options_ptr) {
612   std::string build_cache_dir;
613   auto &ge_options = *session_options_ptr;
614   bool build_cache_enabled = false;
615   std::string output_file;
616   (void)GetConfigOption(lite::kConverterParams, lite::kConverterOutputFile, &output_file);
617   std::string output_dir = "./";
618   if (output_file.find("/") != std::string::npos) {
619     output_dir = output_file.substr(0, output_file.rfind("/") + 1);
620   }
621   session_id_ = GetSessionId();
622   auto ge_session_context = GeSessionManager::GetGeSessionContext(session_id_);
623   if (ge_session_context) {
624     const auto &last_ge_options = ge_session_context->session_options;
625     if (auto it = last_ge_options.find(kGeGraphCompilerCacheDir); it != last_ge_options.end()) {
626       build_cache_dir = it->second;
627       build_cache_enabled = true;
628     }
629   }
630   if (!build_cache_enabled) {
631     std::string mindir_postfix = ".mindir";
632     auto ops = output_file.find(mindir_postfix);
633     if (ops != std::string::npos && ops == output_file.size() - mindir_postfix.size()) {
634       output_file = output_file.substr(0, output_file.size() - mindir_postfix.size());
635     }
636     if (output_file.empty()) {
637       MS_LOG(ERROR) << "Converter output file cannot be empty";
638       return false;
639     }
640     build_cache_dir = output_file + "_variables";
641   }
642   if (lite::CreateDir(build_cache_dir) != RET_OK) {
643     MS_LOG(ERROR) << "Failed to create build cache dir " << build_cache_dir;
644     return false;
645   }
646   ge_options[kGeGraphCompilerCacheDir] = build_cache_dir;
647   MS_LOG(INFO) << "Update session attr " << kGeGraphCompilerCacheDir << " to " << build_cache_dir;
648   if (build_cache_dir.find(output_dir) == 0) {
649     build_cache_relative_dir_ = "./" + build_cache_dir.substr(output_dir.size());
650   }
651   return true;
652 }
653 
654 void GeGraphExecutor::GetGeSessionOptionsFromAscendContext(const std::map<std::string, std::string> &config,
655                                                            std::map<std::string, std::string> *ge_options_ptr) {
656   MS_EXCEPTION_IF_NULL(ge_options_ptr);
657   auto &ge_options = *ge_options_ptr;
658   auto option_id = config.find(lite::kDumpPathKey);
659   if (option_id != config.end()) {
660     auto dump_path = option_id->second;
661     auto real_path = lite::RealPath(dump_path.c_str());
662     std::ifstream ifs(real_path);
663     if (!ifs.good() || !ifs.is_open()) {
664       MS_LOG(EXCEPTION) << "The dump config file: " << real_path << " does not exist or failed to open.";
665     }
666     nlohmann::json dump_cfg_json;
667     try {
668       dump_cfg_json = nlohmann::json::parse(ifs);
669     } catch (const nlohmann::json::parse_error &error) {
670       MS_LOG(EXCEPTION) << "parse json failed, please check the file: " << real_path;
671     }
672     if (dump_cfg_json[kDump] != nullptr && dump_cfg_json[kDump][kDumpMode] != nullptr) {
673       ge_options["ge.exec.enableDump"] = "1";
674       ge_options["ge.exec.dumpMode"] = dump_cfg_json[kDump][kDumpMode].get<std::string>();
675     }
676   }
677   option_id = config.find(lite::kProfilingPathKey);
678   if (option_id != config.end()) {
679     auto profiling_path = option_id->second;
680     auto real_path = lite::RealPath(profiling_path.c_str());
681     std::ifstream ifs(real_path);
682     if (!ifs.good() || !ifs.is_open()) {
683       MS_LOG(EXCEPTION) << "The profiling_path config file: " << real_path << " does not exist or failed to open.";
684     }
685     nlohmann::json profiling_cfg_json;
686     try {
687       profiling_cfg_json = nlohmann::json::parse(ifs);
688     } catch (const nlohmann::json::parse_error &error) {
689       MS_LOG(EXCEPTION) << "parse json failed, please check the file: " << real_path;
690     }
691     if (profiling_cfg_json[kProfiling] != nullptr) {
692       ge_options["ge.exec.profilingMode"] = "1";
693       ge_options["ge.exec.profilingOptions"] = profiling_cfg_json[kProfiling].dump();
694     }
695   }
696   option_id = config.find(lite::kGeVariableMemoryMaxSize);
697   if (option_id != config.end()) {
698     ge_options["ge.variableMemoryMaxSize"] = option_id->second;
699   }
700   option_id = config.find(lite::kGeGraphMemoryMaxSize);
701   if (option_id != config.end()) {
702     ge_options["ge.graphMemoryMaxSize"] = option_id->second;
703   }
704   option_id = config.find(lite::kGraphCompilerCacheDirKey);
705   if (option_id != config.end()) {
706     ge_options[kGeGraphCompilerCacheDir] = option_id->second;
707   }
708 }
709 
710 void GeGraphExecutor::GetGeGraphOptions(const FuncGraphPtr &anf_graph,
711                                         std::map<std::string, std::string> *ge_options_ptr) {
712   MS_EXCEPTION_IF_NULL(anf_graph);
713   MS_EXCEPTION_IF_NULL(ge_options_ptr);
714   auto &ge_options = *ge_options_ptr;
715   auto ascend_device_info = GeUtils::GetAscendDeviceInfo(context_);
716   if (ascend_device_info == nullptr) {
717     MS_LOG(EXCEPTION) << "Failed to get graph session options, can not find ascend device context.";
718   }
719   uint32_t rank_id = ascend_device_info->GetRankID();
720   graph_name_ = std::to_string(rank_id) + "_" + std::to_string(global_graph_idx_) + "_" + anf_graph->ToString();
721   for (auto &c : graph_name_) {
722     if (c == '.') {
723       c = '_';
724     }
725   }
726   ge_options[kGeGraphKey] = graph_name_;
727   auto config_it = config_infos_.find(lite::kGeGraphOptionsSection);
728   if (config_it != config_infos_.end()) {
729     for (auto &item : config_it->second) {
730       ge_options[item.first] = item.second;
731       MS_LOG(INFO) << "Set ge graph option " << item.first << " to " << item.second;
732     }
733   }
734 
735   auto precision_mode = ascend_device_info->GetPrecisionMode();
736   if (!precision_mode.empty()) {
737     ge_options["ge.exec.precision_mode"] = TransforPrecisionToAcl(precision_mode);
738   }
739   config_it = config_infos_.find(lite::kAscendContextSection);
740   if (config_it == config_infos_.end()) {
741     return;
742   }
743   auto config = config_it->second;
744   auto option_id = config.find(lite::kModifyMixList);
745   if (option_id != config.end()) {
746     ge_options["ge.exec.modify_mixlist"] = option_id->second;
747   }
748 }
749 
750 int64_t GeGraphExecutor::GetSessionId() {
751   std::string inner_group_id;
752   (void)GetConfigOption(lite::kLiteInnerGroupSection, lite::kLiteInnerGroupId, &inner_group_id);
753   if (inner_group_id.empty()) {
754     return kUnkonwnSessionId;
755   }
756   int64_t session_id = kUnkonwnSessionId;
757   if (!lite::ConvertStrToInt(inner_group_id, &session_id)) {
758     MS_LOG(WARNING) << "Failed to parse session_id " << inner_group_id << " to int64_t";
759     return kUnkonwnSessionId;
760   }
761   return session_id;
762 }
763 
764 bool GeGraphExecutor::CreateSession(const std::map<std::string, std::string> &extra_options) {
765   if (ge_session_ != nullptr) {
766     MS_LOG(INFO) << "Ge session has already been created";
767     return true;
768   }
769   session_id_ = GetSessionId();
770   (void)setenv("GE_TRAIN", "0", 1);
771   std::map<std::string, std::string> session_options = extra_options;
772   GetGeSessionOptions(&session_options);
773   if (auto option_id = session_options.find(kGeGraphCompilerCacheDir); option_id != session_options.end()) {
774     build_cache_dir_ = option_id->second;
775   }
776   session_options_ = session_options;
777   ge_session_ = GeSessionManager::CreateGeSession(session_id_, session_options);
778   if (ge_session_ == nullptr) {
779     MS_LOG(ERROR) << "Failed to create ge session";
780     return false;
781   }
782   return true;
783 }
784 
785 bool GeGraphExecutor::AddGraph(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &options,
786                                uint32_t *graph_id_ret) {
787   if (ge_session_ == nullptr) {
788     MS_LOG(ERROR) << "Failed to add graph, ge session cannot be nullptr";
789     return false;
790   }
791   auto graph_id = GetNextGraphIdx();
792   for (auto &option : options) {
793     MS_LOG(INFO) << "GE Graph " << graph_id << " option " << option.first << " = " << option.second;
794   }
795   auto ge_status = ge_session_->AddGraph(static_cast<uint32_t>(graph_id), *(graph), options);
796   if (ge_status != ge::GRAPH_SUCCESS) {
797     MS_LOG(ERROR) << "Call GE AddGraph Failed: " << ge::GEGetErrorMsg();
798     return false;
799   }
800   *graph_id_ret = graph_id;
801   return true;
802 }
803 
804 void GeGraphExecutor::GetParams(const FuncGraphPtr &anf_graph, transform::TensorOrderMap *param_tensors) {
805   MS_EXCEPTION_IF_NULL(anf_graph);
806 
807   transform::TensorOrderMap res;
808   for (auto &anf_node : anf_graph->parameters()) {
809     MS_EXCEPTION_IF_NULL(anf_node);
810     auto para = anf_node->cast<ParameterPtr>();
811     MS_EXCEPTION_IF_NULL(para);
812     if (para->has_default()) {
813       auto value = para->default_param();
814       MS_EXCEPTION_IF_NULL(value);
815       auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
816       MS_EXCEPTION_IF_NULL(tensor);
817       auto para_name = para->name();
818       res.emplace(para_name, tensor);
819     }
820   }
821   if (session_id_ != kUnkonwnSessionId) {
822     std::vector<std::string> graph_params;
823     std::transform(res.begin(), res.end(), std::back_inserter(graph_params),
824                    [](const auto &item) { return item.first; });
825     auto new_params_set = GeSessionManager::UpdateSessionVariables(session_id_, graph_params);
826     for (auto &item : res) {
827       // parameters not in new_params_set has been init by other graph
828       if (new_params_set.find(item.first) == new_params_set.end()) {
829         item.second->set_init_flag(true);
830       }
831     }
832   }
833   *param_tensors = res;
834 }
835 
836 bool GeGraphExecutor::UpdateGraphInputs(const FuncGraphPtr &graph) {
837   std::string input_shape_str;
838   std::vector<GeDynamicShapeInfo> input_shapes;
839   if (!GeDynamicUtils::GetGraphInputShapes(context_, config_infos_, &input_shapes, &input_shape_str)) {
840     MS_LOG(ERROR) << "Failed to get input shape from AscendDeviceInfo or config file";
841     return false;
842   }
843   if (input_shapes.empty()) {
844     MS_LOG(INFO) << "Not found input shape in AscendDeviceInfo or config file";
845     return true;
846   }
847   auto inputs = graph->get_inputs();
848   if (inputs.size() != input_shapes.size()) {
849     MS_LOG(ERROR) << "FuncGraph input size " << inputs.size() << " != input size " << input_shapes.size()
850                   << " in AscendDeviceInfo or config file";
851     return false;
852   }
853   for (size_t i = 0; i < input_shapes.size(); i++) {
854     auto node = inputs[i];
855     MS_CHECK_TRUE_RET(node != nullptr, false);
856     auto input_shape = input_shapes[i];
857     auto para = node->cast<ParameterPtr>();
858     if (para == nullptr) {
859       MS_LOG(ERROR) << "Cast input to Parameter failed";
860       return false;
861     }
862     MS_LOG(INFO) << "Func graph input_" << i << " " << para->name()
863                  << ", shape: " << FuncGraphUtils::GetTensorShape({node, 0});
864 
865     auto it = std::find_if(input_shapes.begin(), input_shapes.end(),
866                            [&para](const auto &item) { return item.name == para->name(); });
867     if (it == input_shapes.end()) {
868       MS_LOG(ERROR) << "Failed to find input " << para->name() << " in input_shape " << input_shape_str;
869       return false;
870     }
871     auto abstract = para->abstract();
872     if (abstract == nullptr) {
873       MS_LOG(ERROR) << "Get input abstract failed";
874       return false;
875     }
876     ShapeVector shape;
877     std::transform(it->shape.begin(), it->shape.end(), std::back_inserter(shape), [](auto &dim) { return dim.dim; });
878     MS_LOG(INFO) << "Update shape of input_" << i << " " << para->name() << " to " << shape;
879     abstract->set_shape(std::make_shared<abstract::Shape>(shape));
880   }
881   return true;
882 }
883 
884 bool GeGraphExecutor::InitRefDataList(const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors) {
885   for (auto &item : ref_data_tensors) {
886     auto para_name = item.first;
887     auto &tensor = item.second;
888     RefDataInfo ref_data_info;
889     ref_data_info.name = para_name;
890     ref_data_info.shape = tensor->shape_c();
891     ref_data_info.dtype = tensor->data_type();
892     ref_data_info.host_data = item.second;
893     MS_LOG(INFO) << "Init ref data info[" << ref_data_infos_.size() << "] :" << ref_data_info.name
894                  << ", dtype:" << ref_data_info.dtype << ", shape:" << ref_data_info.shape;
895     ref_data_infos_.push_back(ref_data_info);
896   }
897   return true;
898 }
899 
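// Reuse the memory and context managers registered in the shared session context when available; otherwise create them and register them for later graphs in the same session.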
900 bool GeGraphExecutor::InitMemoryContextManager() {
901   auto session_context = GeSessionManager::GetGeSessionContext(session_id_);
902   if (session_context != nullptr) {
903     memory_manager_ = session_context->memory_manager.lock();
904     context_manager_ = session_context->context_manager.lock();
905   }
906   if (memory_manager_ == nullptr) {
907     memory_manager_ = std::make_shared<GeMemoryManager>();
908     if (memory_manager_ == nullptr) {
909       MS_LOG(ERROR) << "Failed to create memory manager";
910       return false;
911     }
912     if (session_context != nullptr) {
913       session_context->memory_manager = memory_manager_;
914     }
915   }
916   if (context_manager_ == nullptr) {
917     context_manager_ = std::make_shared<GeContextManager>();
918     if (context_manager_ == nullptr) {
919       MS_LOG(ERROR) << "Failed to create context manager";
920       return false;
921     }
922     if (!context_manager_->InitContext(GetDeviceID())) {
923       MS_LOG(ERROR) << "Failed to init device";
924       return false;
925     }
926     if (session_context != nullptr) {
927       session_context->context_manager = context_manager_;
928     }
929   }
930   if (!context_manager_->SetContext()) {
931     MS_LOG(ERROR) << "Failed to set ge context";
932     return false;
933   }
934   return true;
935 }
936 
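// Allocate one contiguous device buffer for RefData tensors not yet owned by the session, copy their host weights to device, and reuse device addresses already recorded in the session context.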
937 bool GeGraphExecutor::InitRefDataDeviceTensor() {
938   MS_LOG(INFO) << "InitRefDataDeviceTensor start.";
939   if (ref_data_infos_.empty()) {
940     MS_LOG(INFO) << "There is no ref data, no need to init ref data device data";
941     return true;
942   }
943   std::map<std::string, RefDataInfo> session_ref_data_map;
944   auto session_context = GeSessionManager::GetGeSessionContext(session_id_);
945   if (session_context != nullptr) {
946     session_ref_data_map = session_context->ref_data_map_;
947   }
948 
949   size_t ref_data_total_size = 0;
950   std::map<std::string, tensor::TensorPtr> new_param_tensor_map;
951   for (size_t i = 0; i < ref_data_infos_.size(); i++) {
952     auto &item = ref_data_infos_[i];
953     auto tensor = item.host_data;
954     item.size = tensor->Size();
955     item.host_data = nullptr;  // release host memory
956     ShapeVector ref_data_shape = tensor->shape_c();
957     SetRefShape(&ref_data_shape, true, item.name);
958     auto desc = transform::TransformUtil::GetGeTensorDesc(ref_data_shape, tensor->data_type(), kOpFormat_NCHW);
959     if (desc == nullptr) {
960       MS_LOG(ERROR) << "Failed to get Tensor Desc";
961       return false;
962     }
963     desc->SetPlacement(::ge::kPlacementDevice);
964     auto ret = item.ge_tensor.SetTensorDesc(*desc);
965     if (ret != ACL_ERROR_NONE) {
966       MS_LOG(ERROR) << "Failed to call ge::Tensor::SetTensorDesc, ret " << ret;
967       return false;
968     }
969     if (auto ref_it = session_ref_data_map.find(item.name); ref_it != session_ref_data_map.end()) {
970       auto &org_item = ref_it->second;
971       MS_LOG(INFO) << "Find RefData " << item.name << ", shape " << org_item.shape << ", size " << org_item.size;
972       if (org_item.size != item.size) {
973         MS_LOG(ERROR) << "RefData " << item.name << " data size != the size in pre graph, current shape " << item.shape
974                       << ", size " << item.size << ", pre shape " << org_item.shape << ", pre size " << org_item.size;
975         return false;
976       }
977       auto dst_addr = ref_it->second.ge_tensor.GetData();
978       ret = item.ge_tensor.SetData(dst_addr, item.size, [](uint8_t *) -> void {});
979       if (ret != ge::GRAPH_SUCCESS) {
980         MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << item.size;
981         return false;
982       }
983     } else {
984       item.offset = ref_data_total_size;
985       ref_data_total_size += ALIGN_UP_REF_DATA(tensor->Size());
986       new_param_tensor_map[item.name] = tensor;
987     }
988   }
989   if (ref_data_total_size != 0) {
990     auto device_memory = memory_manager_->MallocDeviceMemory("RefData input", ref_data_total_size);
991     if (device_memory == nullptr) {
992       return false;
993     }
994     for (auto &item : ref_data_infos_) {
995       auto it = new_param_tensor_map.find(item.name);
996       if (it == new_param_tensor_map.end()) {
997         continue;
998       }
999       auto &tensor_val = it->second;
1000       auto dst_addr = device_memory + item.offset;
1001       if (!memory_manager_->MemcpyHost2Device(dst_addr, item.size, tensor_val->data_c(), tensor_val->Size())) {
1002         MS_LOG(ERROR) << "Failed to memory copy host data to device";
1003         return false;
1004       }
1005       auto ret = item.ge_tensor.SetData(dst_addr, item.size, [](uint8_t *) -> void {});
1006       if (ret != ge::GRAPH_SUCCESS) {
1007         MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << item.size;
1008         return false;
1009       }
1010       if (session_context != nullptr) {
1011         session_context->ref_data_map_[item.name] = item;
1012       }
1013     }
1014   }
1015   return true;
1016 }
1017 
1018 bool GeGraphExecutor::InitInOutDeviceBuffer(const std::string &name, const ShapeVector &shape, TypeId dtype,
1019                                             InOutBufferInfo *buffer_info) {
1020   auto &info = *buffer_info;
1021   auto desc = transform::TransformUtil::GetGeTensorDesc(shape, dtype, kOpFormat_NCHW);
1022   if (desc == nullptr) {
1023     MS_LOG(ERROR) << "Failed to get Tensor Desc";
1024     return false;
1025   }
1026   auto tensor_size = SizeOf(shape) * GetDataTypeSize(dtype);
1027   if (tensor_size <= 0) {
1028     MS_LOG(INFO) << "Failed to calculate " << name << " tensor size, shape " << ShapeVectorToStr(shape)
1029                  << ", data type " << dtype;
1030     return false;
1031   }
1032   desc->SetPlacement(::ge::kPlacementDevice);
1033   auto ret = info.ge_tensor.SetTensorDesc(*desc);
1034   if (ret != ACL_ERROR_NONE) {
1035     MS_LOG(ERROR) << "Failed to call ge::Tensor::SetTensorDesc, ret " << ret;
1036     return false;
1037   }
1038   info.device_addr = memory_manager_->MallocDeviceMemory(name, tensor_size);
1039   if (info.device_addr == nullptr) {
1040     MS_LOG(ERROR) << "Failed to malloc device memory for " << name << ", memory size " << tensor_size
1041                   << ", tensor shape " << shape;
1042     return false;
1043   }
1044   ret = info.ge_tensor.SetData(reinterpret_cast<uint8_t *>(info.device_addr), tensor_size, [](uint8_t *) -> void {});
1045   if (ret != ge::GRAPH_SUCCESS) {
1046     MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << tensor_size;
1047     return false;
1048   }
1049   info.max_size = tensor_size;
1050   info.shape = shape;
1051   info.dtype = dtype;
1052   return true;
1053 }
1054 
1055 bool GeGraphExecutor::UpdateInputShapeOption(
1056   const FuncGraphPtr &func_graph, const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors,
1057   std::map<std::string, std::string> *ge_options_ptr) {
1058   if (ge_options_ptr == nullptr) {
1059     MS_LOG(ERROR) << "Input argument ge_options_ptr cannot be nullptr";
1060     return false;
1061   }
1062   std::string input_shape_str;
1063   std::vector<GeDynamicShapeInfo> input_shapes;
1064   if (!GeDynamicUtils::GetGraphInputShapes(context_, config_infos_, &input_shapes, &input_shape_str)) {
1065     MS_LOG(ERROR) << "Failed to get input shape from AscendDeviceInfo or config file";
1066     return false;
1067   }
1068   std::map<std::string, std::string> shape_map;
1069   if (input_shapes.empty()) {
1070     MS_LOG(INFO) << "Not found input shape in AscendDeviceInfo or config file";
1071     if (!dyn_kv_cache_info_.dynamic_kv_cache) {
1072       return true;
1073     }
1074     auto inputs = func_graph->get_inputs();
1075     bool dyn_input = false;
1076     for (auto &item : inputs) {
1077       auto shape = FuncGraphUtils::GetTensorShape({item, 0});
1078       if (std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; })) {
1079         dyn_input = true;
1080       }
1081       shape_map[item->fullname_with_scope()] = lite::VectorToStrJoin(shape, ",");
1082     }
1083     if (!dyn_input) {
1084       MS_LOG(INFO) << "Current model has no dynamic inputs and there is no ge.inputShape set in config, skip update "
1085                       "ge.inputShape option for dynamic KVCache";
1086       return true;
1087     }
1088   } else {
1089     for (auto &item : input_shapes) {
1090       shape_map[item.name] = item.shape_str;
1091     }
1092   }
1093   for (auto &item : ref_data_tensors) {
1094     ShapeVector ref_dyn_shape = item.second->shape_c();
1095     SetRefShape(&ref_dyn_shape, true, item.first);
1096     shape_map[item.first] = lite::VectorToStrJoin(ref_dyn_shape, ",");
1097   }
1098   std::string new_input_shape_str = lite::MapToStrJoin(shape_map, ":", ";");
1099   GeDynamicUtils::UpdateGraphInputShapes(context_, &config_infos_, new_input_shape_str);
1100   (*ge_options_ptr)["ge.inputShape"] = new_input_shape_str;
1101   MS_LOG(INFO) << "Update ge.inputShape to " << new_input_shape_str;
1102   return true;
1103 }
1104 
1105 bool GeGraphExecutor::InitRefDataContext(const FuncGraphPtr &func_graph,
1106                                          const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors,
1107                                          std::map<std::string, std::string> *ge_options_ptr) {
1108   if (!UpdateInputShapeOption(func_graph, ref_data_tensors, ge_options_ptr)) {
1109     MS_LOG(ERROR) << "Failed to update input shape option";
1110     return false;
1111   }
1112   if (!InitRefDataList(ref_data_tensors)) {
1113     MS_LOG(ERROR) << "Failed to init ref data list";
1114     return false;
1115   }
1116   if (!InitMaxShapeParam()) {
1117     MS_LOG(ERROR) << "Failed to init max shape size";
1118     return false;
1119   }
1120   return true;
1121 }
1122 
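// When a graph compile cache dir is set and weight update is disabled, build a small placeholder graph under the cached graph key so the cached model can be reused instead of converting the full graph.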
1123 transform::DfGraphPtr GeGraphExecutor::CreateFakeGraph(const std::map<std::string, std::string> &ge_options) {
1124   if (enable_update_weight_) {
1125     MS_LOG(INFO) << "Enable update weight, skip create small ge graph";
1126     return nullptr;
1127   }
1128   if (build_cache_dir_.empty()) {
1129     MS_LOG(INFO) << "Option model_cache_mode " << cache_mode_ << " is not mem_opt and not load offline model or "
1130                  << kGeGraphCompilerCacheDir << " is empty, skip create small ge graph";
1131     return nullptr;
1132   }
1133   auto graph_it = ge_options.find(kGeGraphKey);
1134   if (graph_it == ge_options.end()) {
1135     MS_LOG(INFO) << "Cannot find option " << kGeGraphKey << ", skip create small ge graph";
1136     return nullptr;
1137   }
1138   auto graph_key = graph_it->second;
1139   auto idx_file_name = build_cache_dir_ + "/" + graph_key + ".idx";
1140   if (!UpdateOmCacheIdxFile(idx_file_name)) {
1141     return nullptr;
1142   }
1143   auto df_graph = GenExampleGraph(graph_key);
1144   if (df_graph == nullptr) {
1145     MS_LOG(WARNING) << "Failed to create small ge graph for graph " << graph_key << ", skip create small ge graph";
1146     return nullptr;
1147   }
1148   MS_LOG(INFO) << "Create small ge graph for graph " << graph_key;
1149   return df_graph;
1150 }
1151 
1152 bool GeGraphExecutor::UpdateWeights(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> &weights) {
1153   auto time1 = lite::GetTimeUs();
1154   if (init_graph_id_list_.empty()) {
1155     MS_LOG(ERROR) << "init graph id list is empty.";
1156     return false;
1157   }
1158   uint32_t init_graph_id = init_graph_id_list_[0];
1159   MS_LOG(INFO) << "init_graph_id: " << init_graph_id;
1160   if (update_weight_ptr_ == nullptr) {
1161     MS_LOG(ERROR) << "please init update weight class by building the model first.";
1162     return false;
1163   }
1164   std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> new_weight_tensors;
1165   auto ret = update_weight_ptr_->UpdateConstantTensorData(weights, &new_weight_tensors);
1166   if (!ret) {
1167     MS_LOG(ERROR) << "update weight failed.";
1168     return false;
1169   }
1170   MS_LOG(DEBUG) << "ExecInitGraph start.";
1171   auto time2 = lite::GetTimeUs();
1172   MS_LOG(INFO) << "update weight prepare time: " << (time2 - time1) / kNumMicrosecondToMillisecond << " ms";
1173 
1174   // cppcheck-suppress cppcheckError
1175   for (size_t i = 0; i < new_weight_tensors.size(); i++) {
1176     std::vector<::ge::Tensor> ge_inputs;
1177     // cppcheck-suppress cppcheckError
1178     for (size_t j = 0; j < new_weight_tensors[i].size(); j++) {
1179       auto &input = new_weight_tensors[i][j];
1180       auto ge_tensor = transform::TransformUtil::ConvertTensor(input, kOpFormat_NCHW, false);
1181       if (ge_tensor == nullptr) {
1182         MS_LOG(ERROR) << "Failed to convert input " << i << " ME Tensor to GE Tensor";
1183         return false;
1184       }
1185       ge_inputs.emplace_back(*ge_tensor);
1186     }
1187     std::vector<::ge::Tensor> ge_outputs;
1188     auto ge_status = ge_session_->RunGraph(init_graph_id, ge_inputs, ge_outputs);
1189     if (ge_status != ge::GRAPH_SUCCESS) {
1190       MS_LOG(ERROR) << "Exec init graph failed, graph id " << init_graph_id;
1191       return false;
1192     }
1193   }
1194   auto time3 = lite::GetTimeUs();
1195   MS_LOG(INFO) << "update weight run init graph time: " << (time3 - time2) / kNumMicrosecondToMillisecond << " ms";
1196   return true;
1197 }
1198 
1199 transform::DfGraphPtr GeGraphExecutor::CreateGeGraphOnline(const FuncGraphPtr &anf_graph,
1200                                                            std::map<std::string, std::string> *ge_options_ptr) {
1201   std::vector<std::string> extra_variables_names = {};
1202   if (enable_update_weight_ && update_weight_ptr_ != nullptr) {
1203     auto ret = update_weight_ptr_->CreateAddOpNodeForGraph(anf_graph);
1204     if (!ret) {
1205       MS_LOG(ERROR) << "CreateAddOpNodeForGraph failed.";
1206       return nullptr;
1207     }
1208     extra_variables_names = update_weight_ptr_->GetVariableParamsName(anf_graph);
1209     if (extra_variables_names.empty()) {
1210       MS_LOG(WARNING) << "GetVariableParamsName failed.";
1211       return nullptr;
1212     }
1213   }
1214   transform::TensorOrderMap params_vals;
1215   GetParams(anf_graph, &params_vals);
1216   transform::SetDynRefDataFunc dyn_ref_data_func = nullptr;
1217   if (dyn_kv_cache_info_.dynamic_kv_cache) {
1218     dyn_ref_data_func = [this](const AnfNodePtr &node, const ShapeVector &org_shape) -> ShapeVector {
1219       return SetKVCacheShape(dyn_kv_cache_info_.batch_size_dyn, dyn_kv_cache_info_.seq_length_dyn,
1220                              dyn_kv_cache_info_.kv_cache_layout, org_shape);
1221     };
1222   }
1223 
1224   MS_LOG(INFO) << "extra_variables_names size: " << extra_variables_names.size();
1225   auto converter = std::make_shared<transform::DfGraphConvertor>(anf_graph, "", ref_mode_flag_, extra_variables_names,
1226                                                                  dyn_ref_data_func);
1227   transform::BuildGraph(graph_name_, converter, params_vals);
1228   auto err_code = transform::ErrCode(converter);
1229   if (err_code != 0) {
1230     transform::ClearGraph();
1231     MS_LOG(ERROR) << "Convert df graph failed, err:" << err_code;
1232     return nullptr;
1233   }
1234   auto init_graph = transform::GetInitGraph(converter);
1235   if (init_graph != nullptr) {
1236     uint32_t init_graph_id = 0;
1237     if (!AddGraph(init_graph, {}, &init_graph_id)) {
1238       MS_LOG(ERROR) << "Failed to add init graph, graph name " << anf_graph->ToString();
1239       return nullptr;
1240     }
1241     if (enable_update_weight_ && update_weight_ptr_ != nullptr) {
1242       init_graph_id_list_.push_back(init_graph_id);
1243     }
1244     auto init_data_names = converter->GetInitDataNames();
1245     if (enable_update_weight_ && update_weight_ptr_ != nullptr) {
1246       if (!update_weight_ptr_->SetInitDataNames(init_data_names)) {
1247         MS_LOG(ERROR) << "set init data name failed.";
1248         return nullptr;
1249       }
1250     }
1251     // copy init weight to device
1252     if (!RunGeInitGraph(init_graph_id, init_data_names, params_vals)) {
1253       MS_LOG(ERROR) << "Failed to run init graph for " << anf_graph->ToString();
1254       return nullptr;
1255     }
1256     if (!enable_update_weight_) {
1257       ge_session_->RemoveGraph(init_graph_id);
1258     }
1259   } else {
1260     MS_LOG(INFO) << "There is no init graph for graph " << anf_graph->ToString();
1261   }
1262   if (ref_mode_flag_ != transform::RefModeFlag::kRefModeNone) {
1263     auto ref_data_names = converter->GetRefDataNames();
1264     std::vector<std::pair<std::string, tensor::TensorPtr>> ref_datas;
1265     std::transform(ref_data_names.begin(), ref_data_names.end(), std::back_inserter(ref_datas),
1266                    [&params_vals](auto &item) { return std::make_pair(item, params_vals.at(item)); });
1267     if (!InitRefDataContext(anf_graph, ref_datas, ge_options_ptr)) {
1268       MS_LOG(ERROR) << "Failed to init refdata context";
1269       return nullptr;
1270     }
1271   }
1272   auto df_graph = transform::GetComputeGraph(converter);
1273   return df_graph;
1274 }
1275 
SetOptionsIntoOfflineModel(const std::map<std::string,std::string> & graph_options,std::map<std::string,ValuePtr> * attr_map_ptr)1276 void GeGraphExecutor::SetOptionsIntoOfflineModel(const std::map<std::string, std::string> &graph_options,
1277                                                  std::map<std::string, ValuePtr> *attr_map_ptr) {
1278   auto &attr_map = *attr_map_ptr;
1279 
1280   if (!build_cache_relative_dir_.empty()) {
1281     attr_map[lite::kNameAttrWeightDir] = MakeValue(build_cache_relative_dir_);
1282     MS_LOG(INFO) << "Set graph attr " << lite::kNameAttrWeightDir << " to " << build_cache_relative_dir_;
1283   }
1284   // ge session options
1285   auto find_set_option = [](const std::map<std::string, std::string> &from_options,
1286                             std::vector<std::string> *to_options, const std::string &option) {
1287     auto config_it = from_options.find(option);
1288     if (config_it != from_options.end()) {
1289       to_options->push_back(option);
1290       to_options->push_back(config_it->second);
1291     }
1292   };
1293   std::vector<std::string> session_save_options;
1294   find_set_option(session_options_, &session_save_options, "ge.externalWeight");
1295   attr_map[lite::kGeSessionOptionsSection] = MakeValue(session_save_options);
1296 
1297   std::vector<std::string> graph_save_options;
1298   find_set_option(graph_options, &graph_save_options, "ge.inputShape");
1299   find_set_option(graph_options, &graph_save_options, "ge.dynamicDims");
1300   find_set_option(graph_options, &graph_save_options, "ge.dynamicNodeType");
1301   attr_map[lite::kGeGraphOptionsSection] = MakeValue(graph_save_options);
1302 }
1303 
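// Note: LoadOnlineGraph creates the GE session (optionally backed by a model cache directory),
// compiles the ANF graph into a GE graph and adds it to the session. In "mem_opt" cache mode a
// small fake graph may be added instead, presumably so the real graph can be served from the
// build cache.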
LoadOnlineGraph(const FuncGraphPtr & anf_graph,uint32_t * graph_id)1304 bool GeGraphExecutor::LoadOnlineGraph(const FuncGraphPtr &anf_graph, uint32_t *graph_id) {
1305   std::map<std::string, std::string> extra_session_options;
1306   if (!cache_mode_.empty()) {
1307     if (!SetModelCacheDir(&extra_session_options)) {
1308       return false;
1309     }
1310   }
1311   if (!CreateSession(extra_session_options)) {
1312     MS_LOG(ERROR) << "Failed to create ge session";
1313     return false;
1314   }
1315   std::map<std::string, std::string> ge_options;
1316   GetGeGraphOptions(anf_graph, &ge_options);
1317   auto df_graph = CompileGraphCommon(anf_graph, &ge_options);
1318   if (df_graph == nullptr) {
1319     MS_LOG(ERROR) << "Failed to compile graph " << anf_graph->ToString();
1320     return false;
1321   }
1322   if (cache_mode_ == "mem_opt") {
1323     auto fake_df_graph = CreateFakeGraph(ge_options);
1324     if (fake_df_graph != nullptr) {
1325       df_graph = fake_df_graph;
1326     }
1327   }
1328   if (!AddGraph(df_graph, ge_options, graph_id)) {
1329     MS_LOG(ERROR) << "Failed to add compute graph, graph name " << anf_graph->ToString();
1330     return false;
1331   }
1332   return true;
1333 }
1334 
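// Note: CompileGraphCommon is the compile path shared by online build, AOE tuning and offline
// build: it runs optional graph-kernel optimization, the RemoveLoad and AttrToArgs passes, and
// then converts the ANF graph to either a normal GE graph or a data-flow graph.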
CompileGraphCommon(const FuncGraphPtr & anf_graph,std::map<std::string,std::string> * ge_options_ptr)1335 transform::DfGraphPtr GeGraphExecutor::CompileGraphCommon(const FuncGraphPtr &anf_graph,
1336                                                           std::map<std::string, std::string> *ge_options_ptr) {
1337   if (anf_graph == nullptr) {
1338     MS_LOG(ERROR) << "Input param graph is nullptr.";
1339     return nullptr;
1340   }
1341   if (ge_options_ptr == nullptr) {
1342     MS_LOG(ERROR) << "Input param ge_options_ptr is nullptr.";
1343     return nullptr;
1344   }
1345 
1346 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
1347   auto param = ParseGraphKernelConfigs(config_infos_);
1348   if (param != nullptr) {
1349     auto rank_id = common::GetEnv("RANK_ID");
1350     if (rank_id.empty()) {
1351       auto ascend_device_info = GeUtils::GetAscendDeviceInfo(context_);
1352       if (ascend_device_info != nullptr) {
1353         auto rank_id_value = ascend_device_info->GetRankID();
1354         common::SetEnv("RANK_ID", std::to_string(rank_id_value).c_str());
1355       }
1356     }
1357     if (GraphKernelOptimize(anf_graph, param) != lite::RET_OK) {
1358       MS_LOG(ERROR) << "Run graphkernel optimization failed.";
1359       return nullptr;
1360     }
1361   }
1362 #endif
1363 
1364   auto remove_load_pass = std::make_shared<opt::RemoveLoadPass>();
1365   remove_load_pass->Run(anf_graph);
1366 
1367   if (!UpdateGraphInputs(anf_graph)) {
1368     MS_LOG(ERROR) << "Failed to update graph inputs";
1369     return nullptr;
1370   }
1371 
1372   opt::UpdateManager(anf_graph);
1373 
1374   // Convert MindIR attributes to inputs because of dynamic-shape operators.
1375   // For the transformed operators, the GE adapter only supports inputs, not attributes.
1376   auto attr_to_args_pass = std::make_shared<opt::AttrToArgsPass>();
1377   if (attr_to_args_pass == nullptr) {
1378     MS_LOG(ERROR) << "Create AttrToArgsPass failed";
1379     return nullptr;
1380   }
1381   if (!attr_to_args_pass->Run(anf_graph)) {
1382     MS_LOG(ERROR) << "Run AttrToArgsPass failed";
1383     return nullptr;
1384   }
1385 
1386   transform::DfGraphPtr df_graph = nullptr;
1387   auto func_type = anf_graph->get_attr(kAttrFuncType);
1388   is_data_flow_graph_ = func_type != nullptr && GetValue<std::string>(func_type) == kDataFlowGraphType;
1389   if (!is_data_flow_graph_) {
1390     df_graph = CreateGeGraphOnline(anf_graph, ge_options_ptr);
1391   } else {
1392     df_graph = GetDataFlowGraph(anf_graph, *ge_options_ptr);
1393   }
1394   return df_graph;
1395 }
1396 
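// Note: CompileGraph is the delegate entry point. Offline-converted MindIR is rejected here;
// otherwise the graph is compiled online, the compute graph id is recorded, and in ref-data
// mode the graph is built immediately so device buffers are prepared before the first run.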
CompileGraph(const FuncGraphPtr & anf_graph,const std::map<string,string> &,uint32_t * graph_id)1397 bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &anf_graph, const std::map<string, string> &,
1398                                    uint32_t *graph_id) {
1399   MS_CHECK_TRUE_RET(graph_id != nullptr, false);
1400   uint32_t compute_graph_id = 0;
1401   if (CustomAscendUtils::IsCustomFuncGraph(anf_graph)) {
1402     MS_LOG(ERROR) << "Offline converted MindIR is not supported currently";
1403     return false;
1404   } else {
1405     auto ret = LoadOnlineGraph(anf_graph, &compute_graph_id);
1406     if (!ret) {
1407       MS_LOG(ERROR) << "Failed to load online model";
1408       return false;
1409     }
1410   }
1411   compute_graph_id_list_.push_back(compute_graph_id);
1412   *graph_id = compute_graph_id;
1413   if (ref_mode_flag_ != transform::RefModeFlag::kRefModeNone) {
1414     if (!BuildGraphRefMode(anf_graph, compute_graph_id)) {
1415       MS_LOG(ERROR) << "Failed to build ge graph with refdata";
1416       return false;
1417     }
1418   }
1419   std::vector<tensor::TensorPtr> orig_output;
1420   std::vector<std::string> output_names;
1421   FuncGraphUtils::GetFuncGraphOutputsInfo(anf_graph, &orig_output, &output_names);
1422   original_graph_outputs_[*graph_id] = orig_output;
1423   return true;
1424 }
1425 
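// Note: GetOneRealInputs collects one set of concrete input shapes, either from the graph itself
// or from the input_shape configured in AscendDeviceInfo / the config file, and converts them to
// GE tensors. AOE tuning uses these as representative inputs.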
GetOneRealInputs(const FuncGraphPtr & anf_graph,std::vector<ge::Tensor> * ge_tensors_ptr)1426 bool GeGraphExecutor::GetOneRealInputs(const FuncGraphPtr &anf_graph, std::vector<ge::Tensor> *ge_tensors_ptr) {
1427   std::vector<std::pair<std::string, ShapeVector>> input_shapes_configs;
1428   std::string input_shape_str;
1429   if (!GeDynamicUtils::GetGraphOneRealShapes(context_, config_infos_, &input_shapes_configs, &input_shape_str)) {
1430     MS_LOG(ERROR) << "Failed to get one real input shape";
1431     return false;
1432   }
1433   std::vector<tensor::TensorPtr> inputs;
1434   std::vector<std::string> input_names;
1435   FuncGraphUtils::GetFuncGraphInputsInfo(anf_graph, &inputs, &input_names);
1436   if (!input_shapes_configs.empty() && input_shapes_configs.size() != inputs.size()) {
1437     MS_LOG(ERROR) << "Input count " << input_shapes_configs.size()
1438                   << " get from input_shape of AscendDeviceInfo or config file != input count " << inputs.size()
1439                   << " got from graph";
1440     return false;
1441   }
1442   std::vector<::ge::Tensor> ge_inputs;
1443   // cppcheck-suppress cppcheckError
1444   for (size_t i = 0; i < inputs.size(); i++) {
1445     auto &input = inputs[i];
1446     auto input_name = input_names[i];
1447     if (!input_shapes_configs.empty()) {
1448       auto it = std::find_if(input_shapes_configs.begin(), input_shapes_configs.end(),
1449                              [&input_name](const auto &item) { return input_name == item.first; });
1450       if (it == input_shapes_configs.end()) {
1451         MS_LOG(ERROR) << "Cannot find input " << input_name << " in input_shape " << input_shape_str;
1452         return false;
1453       }
1454       input = std::make_shared<tensor::Tensor>(input->data_type(), it->second);
1455     } else if (GeDynamicUtils::IsDynamicInputShapes({input->shape_c()})) {
1456       MS_LOG(ERROR) << "Input " << i << " is dynamic shape " << input->shape_c()
1457                     << ", but there is no input shape specified in AscendDeviceInfo or config file";
1458       return false;
1459     }
1460     MS_LOG(INFO) << "Input " << i << " shape " << input->shape_c() << ", datatype " << input->data_type();
1461     auto ge_tensor = transform::TransformUtil::ConvertTensor(input, kOpFormat_NCHW);
1462     if (ge_tensor == nullptr) {
1463       MS_LOG(ERROR) << "Failed to convert input " << i << " ME Tensor to GE Tensor";
1464       return false;
1465     }
1466     ge_inputs.emplace_back(*ge_tensor);
1467   }
1468   *ge_tensors_ptr = ge_inputs;
1469   return true;
1470 }
1471 
AoeTuning(const FuncGraphPtr & anf_graph)1472 bool GeGraphExecutor::AoeTuning(const FuncGraphPtr &anf_graph) {
1473   if (!CreateSession({})) {
1474     MS_LOG(ERROR) << "Failed to create ge session";
1475     return false;
1476   }
1477   std::map<std::string, std::string> ge_options;
1478   GetGeGraphOptions(anf_graph, &ge_options);
1479   auto df_graph = CompileGraphCommon(anf_graph, &ge_options);
1480   if (df_graph == nullptr) {
1481     MS_LOG(ERROR) << "Failed to compile graph " << anf_graph->ToString();
1482     return false;
1483   }
1484   std::vector<::ge::Tensor> ge_inputs;
1485   if (!GetOneRealInputs(anf_graph, &ge_inputs)) {
1486     MS_LOG(ERROR) << "Failed to get one real inputs";
1487     return false;
1488   }
1489   AoeApiTuning tuning;
1490   auto status = tuning.AoeTurningGraph(ge_session_, df_graph, ge_inputs, context_, config_infos_);
1491   if (status != kSuccess) {
1492     MS_LOG(ERROR) << "Failed to call AoeTurningGraph";
1493     return false;
1494   }
1495   return true;
1496 }
1497 
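// Note: RunGeInitGraph runs the init graph once with the parameter tensors named in
// init_data_names, which copies the constant weights from host to device.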
RunGeInitGraph(uint32_t init_graph_id,const std::vector<std::string> & init_data_names,const transform::TensorOrderMap & params_vals)1498 bool GeGraphExecutor::RunGeInitGraph(uint32_t init_graph_id, const std::vector<std::string> &init_data_names,
1499                                      const transform::TensorOrderMap &params_vals) {
1500   std::vector<tensor::TensorPtr> init_data_tensors;
1501   for (auto &item : init_data_names) {
1502     auto it = params_vals.find(item);
1503     if (it == params_vals.end()) {
1504       MS_LOG(ERROR) << "Cannot find parameter " << item << " in parameter map";
1505       return false;
1506     }
1507     init_data_tensors.push_back(it->second);
1508   }
1509   MS_LOG(DEBUG) << "ExecInitGraph start.";
1510   std::vector<::ge::Tensor> ge_inputs;
1511   for (size_t i = 0; i < init_data_tensors.size(); i++) {
1512     auto &input = init_data_tensors[i];
1513     auto ge_tensor = transform::TransformUtil::ConvertTensor(input, kOpFormat_NCHW, false);
1514     if (ge_tensor == nullptr) {
1515       MS_LOG(ERROR) << "Failed to convert input " << i << " ME Tensor to GE Tensor";
1516       return false;
1517     }
1518     ge_inputs.emplace_back(*ge_tensor);
1519   }
1520   std::vector<::ge::Tensor> ge_outputs;
1521   auto ge_status = ge_session_->RunGraph(init_graph_id, ge_inputs, ge_outputs);
1522   if (ge_status != ge::GRAPH_SUCCESS) {
1523     MS_LOG(ERROR) << "Exec init graph failed, graph id " << init_graph_id;
1524     return false;
1525   }
1526   MS_LOG(INFO) << "Exec init graph success, graph id " << init_graph_id;
1527   return true;
1528 }
1529 
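// Note: RunGeGraphAsync submits the graph with ge::Session::RunGraphAsync and blocks on a promise
// until the GE callback fires; END_OF_SEQUENCE and other GE errors are reported as failures.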
RunGeGraphAsync(uint32_t graph_id,const std::vector<::ge::Tensor> & inputs,std::vector<::ge::Tensor> * outputs)1530 bool GeGraphExecutor::RunGeGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs,
1531                                       std::vector<::ge::Tensor> *outputs) {
1532   bool is_finished = false;
1533   bool end_of_sequence = false;
1534   std::promise<void> promise;
1535   auto call_back = [outputs, &is_finished, &end_of_sequence, &promise](ge::Status ge_status,
1536                                                                        const std::vector<ge::Tensor> &ge_outputs) {
1537     if (ge_status == ge::GRAPH_SUCCESS) {
1538       *outputs = ge_outputs;
1539       is_finished = true;
1540     } else if (ge_status == ge::END_OF_SEQUENCE) {
1541       MS_LOG(ERROR) << "RunAsync out of range: End of sequence.";
1542       end_of_sequence = true;
1543     } else {
1544       MS_LOG(ERROR) << "RunAsync failed." << ge::GEGetErrorMsg();
1545     }
1546     promise.set_value();
1547     return;
1548   };
1549   if (ge_session_ == nullptr) {
1550     MS_LOG(ERROR) << "The GE session is null, can't run the graph!";
1551     return false;
1552   }
1553   ge::Status ret = ge_session_->RunGraphAsync(graph_id, inputs, call_back);
1554   if (ret != ge::GRAPH_SUCCESS) {
1555     MS_LOG(ERROR) << "Call GE RunGraphAsync Failed: " << ge::GEGetErrorMsg();
1556     return false;
1557   }
1558   auto future = promise.get_future();
1559   future.wait();
1560   if (end_of_sequence) {
1561     MS_LOG(ERROR) << "Failed to call GE RunGraphAsync: End of sequence";
1562     return false;
1563   }
1564   return is_finished;
1565 }
1566 
RunDataFlowGraphAsync(uint32_t graph_id,const std::vector<::ge::Tensor> & inputs,std::vector<::ge::Tensor> * outputs)1567 bool GeGraphExecutor::RunDataFlowGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs,
1568                                             std::vector<::ge::Tensor> *outputs) {
1569   ge::DataFlowInfo data_flow_info;
1570   int time_out = 3000;  // set the timeout to 3000s.
1571   auto ret = ge_session_->FeedDataFlowGraph(graph_id, inputs, data_flow_info, time_out);
1572   if (ret != ge::SUCCESS) {
1573     MS_LOG(ERROR) << "Feed input data failed.";
1574     return false;
1575   }
1576   ret = ge_session_->FetchDataFlowGraph(graph_id, *outputs, data_flow_info, time_out);
1577   if (ret != ge::SUCCESS) {
1578     MS_LOG(ERROR) << "Fetch output data failed.";
1579     return false;
1580   }
1581   return true;
1582 }
1583 
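// Note: InitInputDataTensor copies the host inputs into the pre-allocated device buffers, updates
// the GE tensor shapes (including dynamic KV-cache ref-data shapes), and prepares the output GE
// tensors: empty tensors when the GE graph is dynamic, pre-allocated buffers when it is static.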
InitInputDataTensor(const std::vector<tensor::Tensor> & inputs,std::vector<::ge::Tensor> * ge_inputs,std::vector<::ge::Tensor> * ge_outputs)1584 bool GeGraphExecutor::InitInputDataTensor(const std::vector<tensor::Tensor> &inputs,
1585                                           std::vector<::ge::Tensor> *ge_inputs, std::vector<::ge::Tensor> *ge_outputs) {
1586   if (inputs_buffer_infos_.size() != inputs.size()) {
1587     MS_LOG(ERROR) << "Input data info size " << inputs_buffer_infos_.size() << " != inputs size " << inputs.size();
1588     return false;
1589   }
1590   if (memory_manager_ == nullptr) {
1591     MS_LOG(ERROR) << "Memory manager is nullptr";
1592     return false;
1593   }
1594   for (size_t i = 0; i < inputs.size(); i++) {
1595     auto &input = inputs[i];
1596     MS_LOG(INFO) << "Input " << i << " shape " << tensor::ShapeToString(input.shape_c()) << ", datatype "
1597                  << input.data_type();
1598     auto tensor_size = input.Size();
1599     auto &input_info = inputs_buffer_infos_[i];
1600     if (input_info.max_size < tensor_size) {
1601       MS_LOG(ERROR) << "Input " << i << " data size invalid, graph size " << input_info.max_size << ", given size "
1602                     << tensor_size;
1603       return false;
1604     }
1605     if (!memory_manager_->MemcpyHost2Device(input_info.device_addr, input_info.max_size, input.data_c(), tensor_size)) {
1606       return false;
1607     }
1608 
1609     SetGeTensorShape(&input_info.ge_tensor, input.shape_c());
1610     ge_inputs->push_back(input_info.ge_tensor);
1611   }
1612   for (auto &item : ref_data_infos_) {
1613     if (dyn_kv_cache_info_.dynamic_kv_cache) {
1614       ShapeVector ref_real_shape = transform::TransformUtil::ConvertGeShape(item.ge_tensor.GetTensorDesc().GetShape());
1615       SetRefShape(&ref_real_shape, false, item.name);
1616       SetGeTensorShape(&item.ge_tensor, ref_real_shape);
1617       MS_LOG(INFO) << "Update RefData Input " << item.name << " shape to " << tensor::ShapeToString(ref_real_shape);
1618     }
1619     ge_inputs->push_back(item.ge_tensor);
1620   }
1621   if (!dyn_kv_cache_info_.is_ge_graph_static_) {
1622     ge_outputs->resize(outputs_buffer_infos_.size());
1623     for (auto &ge_tensor : *ge_outputs) {
1624       auto ret = ge_tensor.SetData(nullptr, 0U, [](uint8_t *) -> void {});
1625       if (ret != ge::GRAPH_SUCCESS) {
1626         MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(nullptr, 0, DeleteFunc) for output";
1627         return false;
1628       }
1629     }
1630   } else {
1631     for (auto &output : outputs_buffer_infos_) {
1632       ge_outputs->push_back(output.ge_tensor);
1633     }
1634   }
1635   return true;
1636 }
1637 
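// Note: BuildGraphRefMode is the ref-data build path. It compiles the graph in GE and then
// allocates device memory for ref-data, graph inputs and graph outputs so that RunGraphRefMode
// can execute with pre-allocated device buffers.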
BuildGraphRefMode(const FuncGraphPtr & anf_graph,uint32_t graph_id)1638 bool GeGraphExecutor::BuildGraphRefMode(const FuncGraphPtr &anf_graph, uint32_t graph_id) {
1639   MS_LOG(INFO) << "Call GE CompileGraph start, graph id " << graph_id;
1640   ge::Status ret = ge_session_->CompileGraph(graph_id);
1641   if (ret != ge::GRAPH_SUCCESS) {
1642     MS_LOG(ERROR) << "Call GE CompileGraph Failed: " << ge::GEGetErrorMsg();
1643     return false;
1644   }
1645   MS_LOG(INFO) << "Call GE CompileGraph end, graph id " << graph_id;
1646   if (!InitMemoryContextManager()) {
1647     return false;
1648   }
1649 
1650   if (!InitRefDataDeviceTensor()) {
1651     MS_LOG(ERROR) << "Failed to init ref data device data";
1652     return false;
1653   }
1654 
1655   // ref data input memories have been allocated
1656   // for input data memory
1657   if (!InitInputDeviceTensor(anf_graph)) {
1658     MS_LOG(ERROR) << "Failed to init input data device data";
1659     return false;
1660   }
1661 
1662   // for output memory
1663   if (!InitOutputDeviceTensor(anf_graph, graph_id)) {
1664     MS_LOG(ERROR) << "Failed to init output device data";
1665     return false;
1666   }
1667   return true;
1668 }
1669 
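// Note: RunGraphRefMode stages the inputs into device buffers, runs the graph on the default
// stream via RunGraphWithStreamAsync and copies the outputs back to host.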
RunGraphRefMode(uint32_t graph_id,const std::vector<tensor::Tensor> & inputs,std::vector<tensor::Tensor> * outputs)1670 bool GeGraphExecutor::RunGraphRefMode(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
1671                                       std::vector<tensor::Tensor> *outputs) {
1672   MS_LOG(INFO) << "RunGraphRefMode begin";
1673   std::vector<::ge::Tensor> ge_inputs;
1674   std::vector<::ge::Tensor> ge_outputs;
1675   if (!InitRealShapeParam(inputs)) {
1676     return false;
1677   }
1678   if (!InitInputDataTensor(inputs, &ge_inputs, &ge_outputs)) {
1679     MS_LOG(ERROR) << "Init input tensor failed in run graph.";
1680     return false;
1681   }
1682   auto stream = context_manager_->GetDefaultStream();
1683   if (!RunGraphWithStreamAsync(graph_id, stream, ge_inputs, &ge_outputs)) {
1684     MS_LOG(ERROR) << "Failed in run graph with stream async.";
1685     return false;
1686   }
1687   if (!SyncDeviceOutputsToHost(outputs, &ge_outputs)) {
1688     MS_LOG(ERROR) << "Failed in sync device output to host.";
1689     return false;
1690   }
1691   MS_LOG(INFO) << "RunGraphRefMode end";
1692   return true;
1693 }
1694 
SyncDeviceOutputsToHost(std::vector<tensor::Tensor> * outputs,std::vector<::ge::Tensor> * ge_outputs)1695 bool GeGraphExecutor::SyncDeviceOutputsToHost(std::vector<tensor::Tensor> *outputs,
1696                                               std::vector<::ge::Tensor> *ge_outputs) {
1697   UpdateOutputShapeInfo(ge_outputs);
1698 
1699   size_t output_size = outputs_buffer_infos_.size();
1700   if (!outputs->empty()) {
1701     if (outputs->size() != output_size) {
1702       MS_LOG(ERROR) << "Invalid output size, outputs' size " << outputs->size() << ", ge tensor size " << output_size;
1703       return false;
1704     }
1705     // cppcheck-suppress cppcheckError
1706     for (size_t i = 0; i < output_size; ++i) {
1707       auto &output_info = outputs_buffer_infos_[i];
1708       auto &output = (*outputs)[i];
1709       if (output.Size() < output_info.max_size) {
1710         MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output.Size()
1711                           << " is less than actual output size " << output_info.max_size;
1712       }
1713       if ((*outputs)[i].data_c() == nullptr) {
1714         MS_LOG(ERROR) << "Output data ptr is nullptr.";
1715         return false;
1716       }
1717       auto mem_ret = memory_manager_->MemcpyDevice2Host(reinterpret_cast<uint8_t *>(output.data_c()), output.Size(),
1718                                                         output_info.device_addr, output_info.max_size);
1719       if (!mem_ret) {
1720         MS_LOG(ERROR) << "Failed to copy output data, dst size: " << output.Size()
1721                       << ", src size: " << output_info.max_size;
1722         return false;
1723       }
1724       MS_LOG(INFO) << "Output " << i << " shape " << tensor::ShapeToString(output_info.shape) << ", datatype "
1725                    << output_info.dtype;
1726     }
1727   } else {
1728     for (size_t i = 0; i < output_size; i++) {
1729       auto &output_info = outputs_buffer_infos_[i];
1730       tensor::Tensor ms_tensor(output_info.dtype, output_info.shape);
1731       auto mem_ret =
1732         memory_manager_->MemcpyDevice2Host(reinterpret_cast<uint8_t *>(ms_tensor.data_c()), ms_tensor.Size(),
1733                                            output_info.device_addr, output_info.max_size);
1734       if (!mem_ret) {
1735         MS_LOG(ERROR) << "Failed to copy output data, dst size: " << ms_tensor.Size()
1736                       << ", src size: " << output_info.max_size;
1737         return false;
1738       }
1739       MS_LOG(INFO) << "Output " << i << " shape " << tensor::ShapeToString(output_info.shape) << ", datatype "
1740                    << output_info.dtype;
1741       outputs->push_back(ms_tensor);
1742     }
1743   }
1744   return true;
1745 }
1746 
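// Note: RunGraphWithStreamAsync is a thin wrapper around ge::Session::RunGraphWithStreamAsync
// that logs the input shapes, synchronizes the stream and reports the wall-clock execution time.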
RunGraphWithStreamAsync(uint32_t graph_id,void * stream,const std::vector<GeTensor> & inputs,std::vector<GeTensor> * outputs)1747 bool GeGraphExecutor::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector<GeTensor> &inputs,
1748                                               std::vector<GeTensor> *outputs) {
1749   MS_EXCEPTION_IF_NULL(outputs);
1750   for (const auto &ge_input : inputs) {
1751     MS_LOG(INFO) << "In ge graph " << graph_id << ", input for RunGraphWithStreamAsync : "
1752                  << tensor::ShapeToString(
1753                       transform::TransformUtil::ConvertGeShape(ge_input.GetTensorDesc().GetShape()));
1754   }
1755   MS_LOG(INFO) << "Run the graph in GE with " << inputs.size() << " inputs";
1756   struct timeval start_time;
1757   (void)gettimeofday(&start_time, nullptr);
1758 
1759   ge::Status ret = ge_session_->RunGraphWithStreamAsync(graph_id, stream, inputs, *outputs);
1760   if (ret != ge::GRAPH_SUCCESS) {
1761     MS_LOG(ERROR) << "Call GE RunGraphWithStreamAsync Failed, ret is: " << ret;
1762     return false;
1763   }
1764   if (!context_manager_->SyncStream(stream)) {
1765     MS_LOG(ERROR) << "Sync stream for RunGraphWithStreamAsync failed";
1766     return false;
1767   }
1768   struct timeval end_time;
1769   (void)gettimeofday(&end_time, nullptr);
1770   const uint64_t kUSecondInSecond = 1000000;
1771   uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
1772   cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
1773   MS_LOG(INFO) << "Call GE RunGraphWithStreamAsync Success in " << cost << " us, GE outputs num: " << outputs->size()
1774                << ", graph id: " << graph_id;
1775 
1776   return true;
1777 }
1778 
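// Note: RunGraph is the main execution entry. It dispatches to the ref-data path when enabled;
// otherwise it converts the inputs to GE tensors, runs the graph (async or data-flow) and copies
// the outputs back, or wraps them without copy when the caller passes no output tensors.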
RunGraph(uint32_t graph_id,const std::vector<tensor::Tensor> & inputs,std::vector<tensor::Tensor> * outputs,const std::map<string,string> &)1779 bool GeGraphExecutor::RunGraph(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
1780                                std::vector<tensor::Tensor> *outputs,
1781                                const std::map<string, string> & /* compile_options */) {
1782   if (outputs == nullptr) {
1783     MS_LOG(ERROR) << "Output param is nullptr.";
1784     return false;
1785   }
1786   MS_LOG(INFO) << "Run ge graph [" << graph_id << "] with " << inputs.size() << " inputs";
1787   for (size_t i = 0; i < inputs.size(); i++) {
1788     auto &input = inputs[i];
1789     MS_LOG(INFO) << "Input " << i << " shape " << input.shape_c() << ", datatype " << input.data_type();
1790   }
1791 
1792   if (ref_mode_flag_ != transform::RefModeFlag::kRefModeNone) {
1793     return RunGraphRefMode(graph_id, inputs, outputs);
1794   }
1795   std::vector<::ge::Tensor> ge_inputs;
1796   for (size_t i = 0; i < inputs.size(); i++) {
1797     auto &input = inputs[i];
1798     auto ge_tensor =
1799       transform::TransformUtil::ConvertTensor(std::make_shared<tensor::Tensor>(input), kOpFormat_NCHW, false);
1800     if (ge_tensor == nullptr) {
1801       MS_LOG(ERROR) << "Failed to convert input " << i << " ME Tensor to GE Tensor";
1802       return false;
1803     }
1804     ge_inputs.emplace_back(*ge_tensor);
1805   }
1806   for (auto &item : ref_data_infos_) {
1807     ge_inputs.emplace_back(item.ge_tensor);
1808   }
1809   std::vector<::ge::Tensor> ge_outputs;
1810   auto time_start = std::chrono::system_clock::now();
1811   auto ret = !is_data_flow_graph_ ? RunGeGraphAsync(graph_id, ge_inputs, &ge_outputs)
1812                                   : RunDataFlowGraphAsync(graph_id, ge_inputs, &ge_outputs);
1813   if (!ret) {
1814     MS_LOG(ERROR) << "Exec compute graph failed, graph id " << graph_id;
1815     return false;
1816   }
1817   auto time_cost =
1818     std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - time_start).count();
1819   MS_LOG(INFO) << "Call GE RunGraph Success in " << time_cost << " us, graph id " << graph_id
1820                << " the GE outputs num is: " << ge_outputs.size();
1821 
1822   if (!outputs->empty()) {
1823     if (outputs->size() != ge_outputs.size()) {
1824       MS_LOG(ERROR) << "Invalid output size, outputs' size " << outputs->size() << ", ge tensor size "
1825                     << ge_outputs.size();
1826       return false;
1827     }
1828     for (size_t i = 0; i < outputs->size(); ++i) {
1829       const auto &tensor = ge_outputs[i];
1830       auto &output = (*outputs)[i];
1831       if (output.Size() < LongToSize(UlongToLong(tensor.GetSize()))) {
1832         MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output.Size()
1833                           << " is less than actual output size " << tensor.GetSize();
1834       }
1835       if ((*outputs)[i].data_c() == nullptr) {
1836         MS_LOG(ERROR) << "Output data ptr is nullptr.";
1837         return false;
1838       }
1839       auto mem_ret = common::huge_memcpy(reinterpret_cast<uint8_t *>(output.data_c()), output.Size(), tensor.GetData(),
1840                                          tensor.GetSize());
1841       if (mem_ret != EOK) {
1842         MS_LOG(ERROR) << "Failed to copy output data, dst size: " << output.Size()
1843                       << ", src size: " << tensor.GetSize();
1844         return false;
1845       }
1846     }
1847   } else {
1848     for (size_t i = 0; i < ge_outputs.size(); i++) {
1849       auto &ge_tensor = ge_outputs[i];
1850       auto ms_tensor = ConvertGeTensorNoCopy(&ge_tensor, graph_id, i);
1851       if (ms_tensor == nullptr) {
1852         MS_LOG(ERROR) << "Failed to convert output " << i << " GE Tensor to ME Tensor";
1853         return false;
1854       }
1855       MS_LOG(INFO) << "Output " << i << " shape " << tensor::ShapeToString(ms_tensor->shape_c()) << ", datatype "
1856                    << ms_tensor->data_type();
1857       outputs->push_back(*ms_tensor);
1858     }
1859   }
1860   graph_inputs_[graph_id] = inputs;
1861   graph_outputs_[graph_id] = *outputs;
1862   MS_LOG(INFO) << "GE run graph " << graph_id << " end.";
1863   return true;
1864 }
1865 
GetInputInfos(uint32_t graph_id)1866 std::vector<tensor::Tensor> GeGraphExecutor::GetInputInfos(uint32_t graph_id) {
1867   return graph_inputs_.find(graph_id) != graph_inputs_.end() ? graph_inputs_.at(graph_id)
1868                                                              : std::vector<tensor::Tensor>();
1869 }
1870 
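// Note: ConvertGeTensorNoCopy wraps the GE output buffer into a MindSpore tensor without copying,
// after validating host placement, data type, 64-byte alignment and total size against the
// original graph output.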
ConvertGeTensorNoCopy(::ge::Tensor * ge_tensor_ptr,uint32_t graph_id,size_t idx)1871 tensor::TensorPtr GeGraphExecutor::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr, uint32_t graph_id, size_t idx) {
1872   auto &ge_tensor = *ge_tensor_ptr;
1873   auto ge_tensor_desc = ge_tensor.GetTensorDesc();
1874   auto me_shape = transform::TransformUtil::ConvertGeShape(ge_tensor_desc.GetShape());
1875   if (original_graph_outputs_.find(graph_id) == original_graph_outputs_.end()) {
1876     MS_LOG(ERROR) << "Graph original outputs with the given graph id is not found.";
1877     return nullptr;
1878   }
1879   auto original_outputs = original_graph_outputs_[graph_id];
1880   if (idx >= original_outputs.size()) {
1881     MS_LOG(ERROR) << "Graph output index is out of range.";
1882     return nullptr;
1883   }
1884   TypeId type_id = static_cast<TypeId>(original_outputs[idx]->data_type_c());
1885   if (type_id == kTypeUnknown) {
1886     MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
1887                   << static_cast<int>(ge_tensor_desc.GetDataType());
1888     return nullptr;
1889   }
1890   if (ge_tensor_desc.GetPlacement() != ::ge::kPlacementHost) {
1891     MS_LOG(ERROR) << "It is not supported that graph output data's placement is device now.";
1892     return nullptr;
1893   }
1894   auto &&ge_data_uni = ge_tensor.ResetData();
1895   auto deleter = ge_data_uni.get_deleter();
1896   auto ge_data = ge_data_uni.release();
1897   if (ge_data == nullptr) {
1898     MS_LOG(ERROR) << "Ge data cannot be nullptr";
1899     return nullptr;
1900   }
1901   constexpr int64_t kTensorAlignBytes = 64;
1902   if (reinterpret_cast<uintptr_t>(ge_data) % kTensorAlignBytes != 0) {
1903     MS_LOG(ERROR) << "Skip zero-copy of ge tensor " << reinterpret_cast<uintptr_t>(ge_data)
1904                   << ", address is not aligned to " << kTensorAlignBytes << " bytes.";
1905     return nullptr;
1906   }
1907   int64_t elem_num = 1;
1908   for (size_t i = 0; i < me_shape.size(); ++i) {
1909     elem_num *= me_shape[i];
1910   }
1911   if (GetTypeByte(TypeIdToType(type_id)) * elem_num != ge_tensor.GetSize()) {
1912     MS_LOG(ERROR) << "Output datatype error, output tensor size from GE RunGraph does not match the expected size.";
1913     return nullptr;
1914   }
1915   auto tensor_data = std::make_shared<TensorRefData>(ge_data, elem_num, ge_tensor.GetSize(), me_shape.size(), deleter);
1916   return std::make_shared<tensor::Tensor>(type_id, me_shape, tensor_data);
1917 }
1918 
GetOutputInfos(uint32_t graph_id)1919 std::vector<tensor::Tensor> GeGraphExecutor::GetOutputInfos(uint32_t graph_id) {
1920   return graph_outputs_.find(graph_id) != graph_outputs_.end() ? graph_outputs_.at(graph_id)
1921                                                                : std::vector<tensor::Tensor>();
1922 }
1923 
CreateAsCustomFuncGraph(const FuncGraphPtr & func_graph,const std::map<std::string,std::string> & graph_options)1924 bool GeGraphExecutor::CreateAsCustomFuncGraph(const FuncGraphPtr &func_graph,
1925                                               const std::map<std::string, std::string> &graph_options) {
1926   Buffer buffer;
1927   auto files = ReadFileNames(build_cache_dir_);
1928   for (auto &file : files) {
1929     if (file.find(".om") != std::string::npos && file.find(graph_name_) != std::string::npos) {
1930       auto om_path = build_cache_dir_ + "/" + file;
1931       buffer = ReadFile(om_path);
1932       break;
1933     }
1934   }
1935   if (buffer.DataSize() == 0 || buffer.Data() == nullptr) {
1936     MS_LOG(ERROR) << "Failed to read model buffer file, model cache " << build_cache_dir_;
1937     return false;
1938   }
1939   std::map<std::string, ValuePtr> attr_map;
1940   SetOptionsIntoOfflineModel(session_options_, &attr_map);
1941   std::vector<std::string> ref_datas;
1942   std::transform(ref_data_infos_.begin(), ref_data_infos_.end(), std::back_inserter(ref_datas),
1943                  [](auto &item) { return item.name; });
1944   DynKVCacheSaveInfo save_info;
1945   save_info.seq_length_dyn = dyn_kv_cache_info_.seq_length_dyn;
1946   save_info.batch_size_dyn = dyn_kv_cache_info_.batch_size_dyn;
1947   save_info.kv_cache_layout = dyn_kv_cache_info_.kv_cache_layout;
1948 
1949   if (!CustomAscendUtils::CreateCustomFuncGraph(func_graph, buffer, graph_name_, attr_map, ref_datas, save_info)) {
1950     MS_LOG(ERROR) << "Create custom func graph failed";
1951     return false;
1952   }
1953   return true;
1954 }
1955 
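// Note: OfflineBuildGraph (ref-data mode only) compiles the graph with a build cache directory,
// then packages the resulting .om buffer and the options needed at load time into a custom
// func graph via CreateAsCustomFuncGraph.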
OfflineBuildGraph(const FuncGraphPtr & graph)1956 bool GeGraphExecutor::OfflineBuildGraph(const FuncGraphPtr &graph) {
1957   if (ref_mode_flag_ == transform::RefModeFlag::kRefModeNone) {
1958     MS_LOG(INFO) << "parameter_as_refdata in ascend_context is none, skip offline build graph";
1959     return true;
1960   }
1961   MS_LOG(INFO) << "Set offline mode";
1962   std::map<std::string, std::string> extra_session_options;
1963   if (!SetOfflineBuildModelCacheDir(&extra_session_options)) {
1964     return false;
1965   }
1966   if (!CreateSession(extra_session_options)) {
1967     MS_LOG(ERROR) << "Failed to create ge session";
1968     return false;
1969   }
1970 
1971   if (!SetDynamicKVCache(graph)) {
1972     MS_LOG(ERROR) << "Failed to init dynamic KVCache info";
1973     return false;
1974   }
1975   uint32_t graph_id = 0;
1976   std::map<std::string, std::string> ge_options;
1977   GetGeGraphOptions(graph, &ge_options);
1978   auto df_graph = CompileGraphCommon(graph, &ge_options);
1979   if (df_graph == nullptr) {
1980     MS_LOG(ERROR) << "Failed to compile graph " << graph->ToString();
1981     return false;
1982   }
1983   if (!AddGraph(df_graph, ge_options, &graph_id)) {
1984     MS_LOG(ERROR) << "Failed to add compute graph, graph name " << graph->ToString();
1985     return false;
1986   }
1987   compute_graph_id_list_.push_back(graph_id);
1988   MS_LOG(INFO) << "Call GE CompileGraph start, graph id " << graph_id;
1989   ge::Status ret = ge_session_->CompileGraph(graph_id);
1990   if (ret != ge::GRAPH_SUCCESS) {
1991     MS_LOG(ERROR) << "Call GE CompileGraph Failed: " << ge::GEGetErrorMsg();
1992     return false;
1993   }
1994   MS_LOG(INFO) << "Call GE CompileGraph end, graph id " << graph_id;
1995   if (!CreateAsCustomFuncGraph(graph, ge_options)) {
1996     MS_LOG(ERROR) << "Failed to CreateAsCustomFuncGraph";
1997     return false;
1998   }
1999   return true;
2000 }
2001 
2002 std::map<int64_t, std::shared_ptr<GeSessionContext>> GeSessionManager::ge_session_map_;
2003 std::mutex GeSessionManager::session_mutex_;
2004 
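// Note: GeSessionManager::CreateGeSession returns a GE session for the given lite session id.
// Sessions with a known id are cached in ge_session_map_ and shared between models (their session
// options must then match); kUnkonwnSessionId always creates a private, unshared session.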
CreateGeSession(int64_t session_id,const std::map<std::string,std::string> & session_options)2005 std::shared_ptr<ge::Session> GeSessionManager::CreateGeSession(
2006   int64_t session_id, const std::map<std::string, std::string> &session_options) {
2007   std::shared_ptr<ge::Session> ge_session = nullptr;
2008   if (session_id == kUnkonwnSessionId) {
2009     ge_session = std::make_shared<ge::Session>(session_options);
2010     if (ge_session == nullptr) {
2011       MS_LOG(ERROR) << "Failed to create ge session";
2012       return nullptr;
2013     }
2014     MS_LOG(INFO) << "Created ge session successfully; it will not be shared with other graphs";
2015     return ge_session;
2016   }
2017   std::lock_guard<std::mutex> lock(session_mutex_);
2018   auto s_it = ge_session_map_.find(session_id);
2019   if (s_it != ge_session_map_.end() && s_it->second != nullptr) {
2020     ge_session = s_it->second->ge_session.lock();
2021   }
2022   if (ge_session == nullptr) {
2023     for (auto &option : session_options) {
2024       MS_LOG(INFO) << "GE Session (lite session id " << session_id << ") option " << option.first << " = "
2025                    << option.second;
2026     }
2027     ge_session = std::make_shared<ge::Session>(session_options);
2028     if (ge_session == nullptr) {
2029       MS_LOG(ERROR) << "Failed to create ge session";
2030       return nullptr;
2031     }
2032     auto session_context = std::make_shared<GeSessionContext>();
2033     if (session_context == nullptr) {
2034       MS_LOG(ERROR) << "Failed to create GeSessionContext";
2035       return nullptr;
2036     }
2037     session_context->ge_session = ge_session;
2038     session_context->session_options = session_options;
2039     ge_session_map_[session_id] = session_context;
2040     MS_LOG(INFO) << "Created ge session successfully, lite session id: " << session_id;
2041   } else {
2042     auto map_as_string = [](const std::map<std::string, std::string> &options) {
2043       std::stringstream ss;
2044       ss << "{";
2045       for (auto &item : options) {
2046         ss << item.first << ":" << item.second << ",";
2047       }
2048       ss << "}";
2049       return ss.str();
2050     };
2051     auto old_options = s_it->second->session_options;
2052     if (old_options != session_options) {
2053       MS_LOG(ERROR)
2054         << "Session options are not equal across config infos when models' weights are shared, last session options: "
2055         << map_as_string(old_options) << ", current session options: " << map_as_string(session_options);
2056       return nullptr;
2057     }
2058     MS_LOG(INFO) << "Get ge session from session map, lite session id: " << session_id;
2059   }
2060   return ge_session;
2061 }
2062 
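// Note: UpdateSessionVariables returns the graph variables that are new to the shared session and
// records them in the session context, presumably so shared weights are only initialized once per
// session.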
UpdateSessionVariables(int64_t session_id,const std::vector<std::string> & graph_variables)2063 std::set<std::string> GeSessionManager::UpdateSessionVariables(int64_t session_id,
2064                                                                const std::vector<std::string> &graph_variables) {
2065   std::set<std::string> new_variables;
2066   if (session_id == kUnkonwnSessionId) {
2067     std::transform(graph_variables.begin(), graph_variables.end(), std::inserter(new_variables, new_variables.begin()),
2068                    [](const auto &item) { return item; });
2069     return new_variables;
2070   }
2071   std::lock_guard<std::mutex> lock(session_mutex_);
2072   std::shared_ptr<ge::Session> ge_session = nullptr;
2073   auto s_it = ge_session_map_.find(session_id);
2074   if (s_it != ge_session_map_.end() && s_it->second != nullptr) {
2075     ge_session = s_it->second->ge_session.lock();
2076   }
2077   if (ge_session == nullptr) {
2078     std::transform(graph_variables.begin(), graph_variables.end(), std::inserter(new_variables, new_variables.begin()),
2079                    [](const auto &item) { return item; });
2080     return new_variables;
2081   }
2082   auto &current_session_variables = s_it->second->session_variables;
2083   for (auto &item : graph_variables) {
2084     if (current_session_variables.find(item) == current_session_variables.end()) {
2085       new_variables.insert(item);
2086       current_session_variables.insert(item);
2087     }
2088   }
2089   return new_variables;
2090 }
2091 
TryReleaseGeSessionContext(int64_t session_id)2092 void GeSessionManager::TryReleaseGeSessionContext(int64_t session_id) {
2093   std::lock_guard<std::mutex> lock(session_mutex_);
2094   auto s_it = ge_session_map_.find(session_id);
2095   if (s_it != ge_session_map_.end()) {
2096     if (s_it->second != nullptr) {
2097       auto ge_session = s_it->second->ge_session.lock();
2098       if (ge_session == nullptr) {
2099         ge_session_map_.erase(s_it);
2100       }
2101     } else {
2102       ge_session_map_.erase(s_it);
2103     }
2104   }
2105 }
2106 
GetGeSessionContext(int64_t session_id)2107 std::shared_ptr<GeSessionContext> GeSessionManager::GetGeSessionContext(int64_t session_id) {
2108   std::lock_guard<std::mutex> lock(session_mutex_);
2109   auto s_it = ge_session_map_.find(session_id);
2110   if (s_it != ge_session_map_.end()) {
2111     return s_it->second;
2112   }
2113   return nullptr;
2114 }
2115 
GeGraphExecutorCreator(const std::shared_ptr<Context> & ctx,const ConfigInfos & config_infos)2116 static std::shared_ptr<device::GraphExecutor> GeGraphExecutorCreator(const std::shared_ptr<Context> &ctx,
2117                                                                      const ConfigInfos &config_infos) {
2118   auto ge_executor = std::make_shared<GeGraphExecutor>(ctx, config_infos);
2119   if (ge_executor == nullptr || !ge_executor->Init()) {
2120     MS_LOG(ERROR) << "Failed to init GeGraphExecutor";
2121     return nullptr;
2122   }
2123   return ge_executor;
2124 }
2125 
2126 REG_DELEGATE(kAscend, kProviderGe, GeGraphExecutorCreator)
2127 }  // namespace mindspore
2128