1 /**
2 * Copyright 2022-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "extendrt/delegate/ascend_ge/ge_graph_executor.h"
18 #include <tuple>
19 #include <algorithm>
20 #include <utility>
21 #include "mindspore/core/ops/framework_ops.h"
22 #include "extendrt/delegate/factory.h"
23 #include "include/common/utils/scoped_long_running.h"
24 #include "include/api/context.h"
25 #include "include/api/status.h"
26 #include "include/transform/graph_ir/utils.h"
27 #include "include/backend/device_type.h"
28 #include "runtime/device/ms_device_shape_transfer.h"
29 #include "src/common/common.h"
30 #include "src/common/file_utils.h"
31 #include "cxx_api/acl_utils.h"
32 #include "mindspore/core/utils/ms_utils_secure.h"
33 #include "tools/optimizer/common/gllo_utils.h"
34 #include "tools/optimizer/graph/remove_load_pass.h"
35 #include "src/extendrt/utils/func_graph_utils.h"
36 #include "transform/graph_ir/transform_util.h"
37 #include "flow_graph/data_flow.h"
38 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
39 #include "tools/graph_kernel/converter/graph_kernel_optimization.h"
40 #endif
41 #include "src/extendrt/utils/tensor_utils.h"
42 #include "external/ge_common/ge_api_error_codes.h"
43 #include "src/extendrt/delegate/ascend_ge/aoe_api_tune_process.h"
44 #include "extendrt/delegate/ascend_ge/ge_utils.h"
45 #include "extendrt/delegate/ascend_ge/ge_dynamic_utils.h"
46 #include "mindspore/core/ops/lite_ops.h"
47 #include "mindspore/core/ops/nn_optimizer_ops.h"
48 #include "mindspore/lite/tools/common/string_util.h"
49 #include "mindspore/lite/src/extendrt/cxx_api/file_utils.h"
50 #include "mindspore/core/ops/custom.h"
51 #include "mindspore/lite/src/common/common.h"
52 #include "mindspore/lite/tools/common/custom_ascend_utils.h"
53 #include "op_proto/inc/array_ops.h"
54 #include "op_proto/inc/elewise_calculation_ops.h"
55 #include "mindspore/lite/tools/optimizer/graph/attr_to_args_pass.h"
56 #include "mindspore/core/ops/nn_ops.h"
57 #include <nlohmann/json.hpp>
58
namespace mindspore {
namespace {
// Delegate provider name for the GE backend.
constexpr auto kProviderGe = "ge";
// Keys read from the user-supplied dump json config file.
constexpr auto kDump = "dump";
constexpr auto kDumpMode = "dump_mode";
// Key read from the user-supplied profiling json config file.
constexpr auto kProfiling = "profiler";
// Graph-type marker for data-flow graphs.
constexpr auto kDataFlowGraphType = "data_flow";
// Expected node count of the data-flow Custom cnode (primitive + payload input).
constexpr auto kCustomInputSize = 2;
// Config section carrying graph-kernel flags.
constexpr auto kGraphKernelParam = "graph_kernel_param";
// Session id used when no inner group id is configured.
constexpr auto kUnkonwnSessionId = -1;
// Accepted values for the parameter-as-ref-data option.
constexpr auto kRefModeNone = "none";
constexpr auto kRefModeVariable = "variable";
constexpr auto kRefModeAll = "all";
constexpr float kNumMicrosecondToMillisecond = 1000.0;
// Extra bytes reserved when sizing RefData device buffers (see ALIGN_UP_REF_DATA).
constexpr size_t kAlignRefData = 32;
74
ALIGN_UP_REF_DATA(size_t size)75 size_t ALIGN_UP_REF_DATA(size_t size) {
76 return ((size + kMemAlignSize + kAlignRefData - 1) / kMemAlignSize) * kMemAlignSize;
77 }
78
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
// Build a ConverterPara carrying graph-kernel flags assembled from the
// "graph_kernel_param" config section; returns nullptr when that section is absent.
std::shared_ptr<ConverterPara> ParseGraphKernelConfigs(const ConfigInfos &maps) {
  auto section_it = maps.find(kGraphKernelParam);
  if (section_it == maps.end()) {
    return nullptr;
  }
  // Serialize every option as "--key=value " (each entry keeps a trailing space).
  std::string flags;
  for (const auto &[key, value] : section_it->second) {
    flags += "--" + key + "=" + value + " ";
  }
  auto param = std::make_shared<ConverterPara>();
  param->device = GetSocVersion();
  param->graphKernelParam.graph_kernel_flags = flags;
  return param;
}
#endif
95
GenExampleGraph(const std::string & name)96 transform::DfGraphPtr GenExampleGraph(const std::string &name) {
97 MS_LOG(INFO) << "Gen fake graph name is " << name;
98 auto graph = std::make_shared<transform::DfGraph>(name);
99 auto shape_data = std::vector<int64_t>({1, 1, 1, 1});
100 transform::GeTensorDesc desc_data(ge::Shape(shape_data), ge::FORMAT_ND, ge::DT_FLOAT16);
101 auto data = ge::op::Data("data");
102 data.set_attr_index(0);
103 data.update_input_desc_x(desc_data);
104 data.update_output_desc_y(desc_data);
105 auto add = ge::op::Add("add").set_input_x1(data).set_input_x2(data);
106 std::vector<transform::Operator> inputs{data};
107 std::vector<transform::Operator> outputs{add};
108 graph->SetInputs(inputs);
109 graph->SetOutputs(outputs);
110 return graph;
111 }
112
UpdateOmCacheIdxFile(const std::string & idx_file_name)113 bool UpdateOmCacheIdxFile(const std::string &idx_file_name) {
114 std::ifstream ifs(idx_file_name);
115 if (!ifs.good() || !ifs.is_open()) {
116 MS_LOG(INFO) << "model cache idx json not exists, idx file: " << idx_file_name << ", skip create small ge graph";
117 return false;
118 }
119 nlohmann::json dump_cfg_json;
120 try {
121 dump_cfg_json = nlohmann::json::parse(ifs);
122 const std::string cache_file_list = "cache_file_list";
123 const std::string var_desc_file_name = "var_desc_file_name";
124 auto &cache_file_config = dump_cfg_json[cache_file_list];
125 if (cache_file_config == nullptr || cache_file_config[0] == nullptr) {
126 MS_LOG(WARNING) << "model cache idx json content invalid, idx file: " << idx_file_name
127 << ", skip create small ge graph";
128 return false;
129 }
130 auto &config = cache_file_config[0];
131 if (config[var_desc_file_name] != nullptr) {
132 config.erase(var_desc_file_name);
133 auto new_json_str = dump_cfg_json.dump(4);
134 ifs.close();
135 std::ofstream ofs(idx_file_name, std::ios::out);
136 if (!ofs.is_open()) {
137 MS_LOG(WARNING) << "Failed to open model cache idx file for write, idx file: " << idx_file_name
138 << ", skip create small ge graph";
139 return false;
140 }
141 ofs << new_json_str;
142 ofs.close();
143 #ifndef _MSC_VER
144 chmod(idx_file_name.c_str(), S_IRUSR);
145 #endif
146 MS_LOG(INFO) << "Erase option " << var_desc_file_name;
147 }
148 return true;
149 } catch (const std::exception &error) {
150 MS_LOG(WARNING) << "parse model cache idx json failed, idx file: " << idx_file_name
151 << ", skip create small ge graph";
152 return false;
153 }
154 }
155 } // namespace
156
157 std::atomic_uint32_t GeGraphExecutor::global_graph_idx_ = 0;
GetNextGraphIdx()158 uint32_t GeGraphExecutor::GetNextGraphIdx() { return global_graph_idx_++; }
GetDataFlowGraph(const FuncGraphPtr & anf_graph,const std::map<std::string,std::string> & ge_options)159 transform::DfGraphPtr GetDataFlowGraph(const FuncGraphPtr &anf_graph,
160 const std::map<std::string, std::string> &ge_options) {
161 MS_EXCEPTION_IF_NULL(anf_graph);
162 auto return_node = anf_graph->get_return();
163 MS_EXCEPTION_IF_NULL(return_node);
164 auto nodes = anf_graph->TopoSort(return_node);
165 auto itr = std::find_if(nodes.begin(), nodes.end(), [&](const AnfNodePtr &node) {
166 return node != nullptr && node->isa<CNode>() && opt::CheckPrimitiveType(node, prim::kPrimCustom);
167 });
168 if (itr == nodes.end()) {
169 MS_LOG(ERROR) << "The dataflow graph is invalid.";
170 return nullptr;
171 }
172 auto custom_cnode = (*itr)->cast<CNodePtr>();
173 MS_EXCEPTION_IF_NULL(custom_cnode);
174 if (custom_cnode->size() != kCustomInputSize) {
175 MS_LOG(ERROR) << "The input of dataflow custom node is not 2.";
176 return nullptr;
177 }
178 auto tensor = FuncGraphUtils::GetConstNodeValue(custom_cnode->input(1));
179 MS_EXCEPTION_IF_NULL(tensor);
180 auto data = tensor->data_c();
181 MS_EXCEPTION_IF_NULL(data);
182 auto flow_graph = reinterpret_cast<ge::dflow::FlowGraph *>(data);
183 MS_EXCEPTION_IF_NULL(flow_graph);
184 auto df_graph = std::make_shared<transform::DfGraph>(flow_graph->ToGeGraph());
185 return df_graph;
186 }
187
// Tear down all GE graphs owned by this executor and release the GE session.
// All cleanup is skipped when no session was ever created.
GeGraphExecutor::~GeGraphExecutor() {
  if (ge_session_) {
    // Init graphs are local to this executor: simply remove them from the session.
    for (auto graph_id : init_graph_id_list_) {
      ge_session_->RemoveGraph(graph_id);
    }
    // Compute graphs are additionally tracked in the shared session context;
    // unregister each id there so other executors sharing session_id_ no
    // longer see them.
    for (auto graph_id : compute_graph_id_list_) {
      ge_session_->RemoveGraph(graph_id);
      auto session_context = GeSessionManager::GetGeSessionContext(session_id_);
      if (session_context != nullptr) {
        (void)session_context->feature_graph_ids.erase(graph_id);
      }
    }
    ge_session_ = nullptr;
    // NOTE(review): presumably the manager ref-counts sessions and only frees
    // the context when no other executor still uses session_id_ — confirm.
    GeSessionManager::TryReleaseGeSessionContext(session_id_);
    enable_update_weight_ = false;
    update_weight_ptr_ = nullptr;
  }
}
206
SetGeTensorShape(GeTensor * ge_tensor,ShapeVector shape)207 bool GeGraphExecutor::SetGeTensorShape(GeTensor *ge_tensor, ShapeVector shape) {
208 auto ge_desc = ge_tensor->GetTensorDesc();
209 ge::Shape new_ge_shape(shape);
210 ge_desc.Update(new_ge_shape);
211 ge_desc.SetOriginShape(new_ge_shape);
212 ge_tensor->SetTensorDesc(ge_desc);
213 MS_LOG(INFO) << "In SetGeTensorShape update ge shape to :" << shape;
214 return true;
215 }
216
InitInputDeviceTensor(const FuncGraphPtr & anf_graph)217 bool GeGraphExecutor::InitInputDeviceTensor(const FuncGraphPtr &anf_graph) {
218 MS_LOG(INFO) << "Call InitInputDeviceTensor start.";
219 auto inputs = anf_graph->get_inputs();
220 inputs_buffer_infos_.resize(inputs.size());
221 for (size_t i = 0; i < inputs.size(); i++) {
222 auto &input_info = inputs_buffer_infos_[i];
223 auto shape = FuncGraphUtils::GetTensorShape({inputs[i], 0});
224
225 /* set max_batch size and max_seq_len for dyn shape */
226 std::vector<int64_t> new_shape;
227 for (size_t j = 0; j < shape.size(); j++) {
228 if (shape[j] == abstract::Shape::kShapeDimAny) {
229 new_shape.push_back(dyn_kv_cache_info_.max_seq_len_size);
230 } else {
231 new_shape.push_back(shape[j]);
232 }
233 }
234
235 MS_LOG(INFO) << "Init input_" << i << " buffer for ge, change shape: " << shape << " -> " << new_shape;
236 auto dtype = static_cast<TypeId>(FuncGraphUtils::GetTensorDataType({inputs[i], 0}));
237 if (!InitInOutDeviceBuffer("Input " + std::to_string(i), new_shape, dtype, &input_info)) {
238 return false;
239 }
240 }
241 return true;
242 }
243
// Query GE for the compiled-graph summary of `graph_id` and, when GE reports
// the compiled graph as fully static, pre-allocate one device buffer per
// func-graph output using the output shapes GE computed. For non-static
// graphs only the buffer-info vector is sized; buffers are resolved later at
// run time. Returns false on any GE or func-graph query failure.
bool GeGraphExecutor::InitOutputDeviceTensor(const FuncGraphPtr &anf_graph, uint32_t graph_id) {
  MS_LOG(INFO) << "Call GE GetCompiledGraphSummary start, graph id " << graph_id;
  auto graph_summary = ge_session_->GetCompiledGraphSummary(graph_id);
  if (graph_summary == nullptr) {
    MS_LOG(ERROR) << "Failed to call GE GetCompiledGraphSummary, graph id " << graph_id
                  << ", error: " << ge::GEGetErrorMsg();
    return false;
  }
  MS_LOG(INFO) << "Call GE GetCompiledGraphSummary end, graph id " << graph_id;
  // Remember whether the compiled graph is static; run/update logic elsewhere
  // branches on this flag.
  dyn_kv_cache_info_.is_ge_graph_static_ = graph_summary->IsStatic();
  MS_LOG(INFO) << "GE graph is static :" << dyn_kv_cache_info_.is_ge_graph_static_ << ", graph id: " << graph_id;
  std::vector<AnfWithOutIndex> outputs;
  if (!FuncGraphUtils::GetFuncGraphOutputs(anf_graph, &outputs)) {
    MS_LOG(ERROR) << "Failed to get func graph outputs";
    return false;
  }
  outputs_buffer_infos_.resize(outputs.size());
  if (dyn_kv_cache_info_.is_ge_graph_static_) {
    std::vector<::ge::Shape> ge_shapes;
    auto ge_status = graph_summary->GetOutputShapes(ge_shapes);
    if (ge_status != ge::GRAPH_SUCCESS) {
      MS_LOG(ERROR) << "Failed to call GetOutputShapes, status: " << ge_status;
      return false;
    }
    // GE outputs must map one-to-one onto the func graph's outputs.
    if (outputs.size() != ge_shapes.size()) {
      MS_LOG(ERROR) << "Output count got from graph " << outputs.size() << " != that " << ge_shapes.size()
                    << " got from GE";
      return false;
    }
    for (size_t i = 0; i < outputs.size(); i++) {
      auto &output_info = outputs_buffer_infos_[i];
      auto shape = ge_shapes[i].GetDims();
      // Shape comes from GE; dtype comes from the original func-graph output.
      auto dtype = static_cast<TypeId>(FuncGraphUtils::GetTensorDataType(outputs[i]));
      if (!InitInOutDeviceBuffer("Output " + std::to_string(i), shape, dtype, &output_info)) {
        return false;
      }
    }
  }
  return true;
}
284
SetRefShape(std::vector<int64_t> * ref_shape,bool dyn,std::string tensor_name)285 void GeGraphExecutor::SetRefShape(std::vector<int64_t> *ref_shape, bool dyn, std::string tensor_name) {
286 if (!dyn_kv_cache_info_.dynamic_kv_cache) {
287 return;
288 }
289 size_t b_index = kDim0;
290 size_t s_index = kDim2;
291 if (dyn_kv_cache_info_.kv_cache_layout == lite::kKVCacheLayoutBSH) {
292 s_index = kDim1;
293 }
294 if (dyn) {
295 if (dyn_kv_cache_info_.batch_size_dyn) {
296 (*ref_shape)[b_index] = abstract::Shape::kShapeDimAny;
297 MS_LOG(INFO) << "for " << tensor_name << " update batch size to dyn(-1) for ge_option.";
298 }
299 if (dyn_kv_cache_info_.seq_length_dyn) {
300 (*ref_shape)[s_index] = abstract::Shape::kShapeDimAny;
301 MS_LOG(INFO) << "for " << tensor_name << " update seq length size to dyn(-1) for ge_option.";
302 }
303 } else {
304 if (dyn_kv_cache_info_.batch_size_dyn) {
305 (*ref_shape)[b_index] = dyn_kv_cache_info_.real_batch_size;
306 MS_LOG(INFO) << "for " << tensor_name << " update batch size to " << dyn_kv_cache_info_.real_batch_size
307 << " for ge_option.";
308 }
309 if (dyn_kv_cache_info_.seq_length_dyn) {
310 (*ref_shape)[s_index] = dyn_kv_cache_info_.real_seq_len_size;
311 MS_LOG(INFO) << "for " << tensor_name << " update seq length size to " << dyn_kv_cache_info_.real_seq_len_size
312 << " for ge_option.";
313 }
314 }
315 }
316
UpdateOutputShapeInfo(std::vector<::ge::Tensor> * ge_outputs)317 void GeGraphExecutor::UpdateOutputShapeInfo(std::vector<::ge::Tensor> *ge_outputs) {
318 MS_LOG(INFO) << "Update output dtype and shape.";
319 for (size_t i = 0; i < outputs_buffer_infos_.size(); i++) {
320 auto &output_info = outputs_buffer_infos_[i];
321 auto &ge_output = ge_outputs->at(i);
322 auto ge_tensor_desc = ge_output.GetTensorDesc();
323 output_info.shape = transform::TransformUtil::ConvertGeShape(ge_tensor_desc.GetShape());
324 output_info.dtype = transform::TransformUtil::ConvertGeDataType(ge_tensor_desc.GetDataType());
325 output_info.max_size = SizeOf(output_info.shape) * GetDataTypeSize(output_info.dtype);
326 auto out_device = ge_output.GetData();
327 if (dyn_kv_cache_info_.is_ge_graph_static_ && out_device != output_info.device_addr) {
328 MS_LOG(WARNING) << "GE output device address not equal malloc device memory when graph is static";
329 }
330 output_info.device_addr = out_device;
331 MS_LOG(INFO) << "Update output_" << i << " dtype: " << output_info.dtype << ", shape: " << output_info.shape;
332 }
333 return;
334 }
335
// Detect whether the graph uses a dynamic (variable seq-length) KV cache and
// record the result in dyn_kv_cache_info_. Heuristic: when any graph input
// has a dynamic dim, inspect the FIRST PromptKVCache node; if its kv input is
// not produced by PadV3, the cache is treated as dynamic and its layout is
// inferred from the kv input's rank (4D -> BNSD, 3D -> BSH). Returns false
// only on an unexpected kv shape rank.
bool GeGraphExecutor::SetDynamicKVCache(const FuncGraphPtr &func_graph) {
  auto graph_inputs = func_graph->get_inputs();
  // Any negative dim on any input marks the graph as dynamically shaped.
  auto has_dynamic_input = std::any_of(graph_inputs.begin(), graph_inputs.end(), [](const AnfNodePtr &input) {
    auto shape = FuncGraphUtils::GetTensorShape({input, 0});
    return std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; });
  });
  if (!has_dynamic_input) {
    MS_LOG(INFO) << "Not detect dynamic input in graph";
    return true;
  }
  auto nodes = func_graph->TopoSort(func_graph->get_return());
  if (nodes.empty()) {
    MS_LOG(WARNING) << "There are no nodes in the graph";
    return true;
  }
  constexpr size_t kv_index = 2;  // primitive, kv cache, kv
  for (auto &node : nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!cnode || !IsPrimitiveCNode(cnode, prim::kPrimPromptKVCache)) {
      continue;
    }
    auto inputs = cnode->inputs();
    if (inputs.size() <= kv_index) {
      MS_LOG(WARNING) << "PrimPromptKVCache " << cnode->fullname_with_scope() << " input size " << inputs.size() - 1
                      << " <= kv index " << kv_index - 1;
      continue;
    }
    auto kv_input = inputs[kv_index];
    if (kv_input == nullptr) {
      MS_LOG(WARNING) << "PrimPromptKVCache " << cnode->fullname_with_scope() << " kv input is nullptr";
      continue;
    }
    // NOTE(review): a PadV3 producer presumably means kv is padded to a fixed
    // max length, keeping the cache static — confirm with the converter pass.
    if (!IsPrimitiveCNode(kv_input, prim::kPrimPadV3)) {
      dyn_kv_cache_info_.dynamic_kv_cache = true;
      dyn_kv_cache_info_.seq_length_dyn = true;
      auto kv_shape = FuncGraphUtils::GetTensorShape({kv_input, 0});
      if (kv_shape.size() == kShape4dDims) {
        dyn_kv_cache_info_.kv_cache_layout = lite::kKVCacheLayoutBNSD;
      } else if (kv_shape.size() == kShape3dDims) {
        dyn_kv_cache_info_.kv_cache_layout = lite::kKVCacheLayoutBSH;
      } else {
        MS_LOG(ERROR) << "Expect RefData shape to be BNSD or BSH when dynamic kv cache is enable, but got " << kv_shape;
        return false;
      }
    }
    // Only the first PromptKVCache node is inspected.
    break;
  }
  MS_LOG(INFO) << "set dyn kv info dynamic_kv_cache : " << dyn_kv_cache_info_.dynamic_kv_cache;
  MS_LOG(INFO) << "set dyn kv info seq_length_dyn : " << dyn_kv_cache_info_.seq_length_dyn;
  return true;
}
387
CheckRefDataInfo()388 bool GeGraphExecutor::CheckRefDataInfo() {
389 if (!dyn_kv_cache_info_.dynamic_kv_cache) {
390 return true;
391 }
392 auto &ref_shape = ref_data_infos_.front().shape;
393 for (size_t i = 0; i < ref_data_infos_.size(); i++) {
394 auto &ref_data_info = ref_data_infos_[i];
395 auto ¶_name = ref_data_info.name;
396 if (dyn_kv_cache_info_.kv_cache_layout == lite::kKVCacheLayoutBSH) {
397 if (ref_data_info.shape.size() != kShape3dDims) {
398 MS_LOG(ERROR) << "KVCache shape size is not " << kShape3dDims << ", while KVCache layout is "
399 << dyn_kv_cache_info_.kv_cache_layout << ", KVCache param " << para_name << ", shape "
400 << ref_data_info.shape;
401 return false;
402 }
403 } else if (dyn_kv_cache_info_.kv_cache_layout == lite::kKVCacheLayoutBNSD) {
404 if (ref_data_info.shape.size() != kShape4dDims) {
405 MS_LOG(ERROR) << "KVCache shape size is not " << kShape4dDims << ", while KVCache layout is "
406 << dyn_kv_cache_info_.kv_cache_layout << ", KVCache param " << para_name << ", shape "
407 << ref_data_info.shape;
408 return false;
409 }
410 } else {
411 MS_LOG(ERROR) << "Unsupported KVCache layout " << dyn_kv_cache_info_.kv_cache_layout;
412 return false;
413 }
414 if (ref_shape != ref_data_info.shape) {
415 MS_LOG(ERROR) << "KVCache shape " << ref_data_info.shape << " of " << para_name << " != KVCache shape "
416 << ref_shape << " of " << ref_data_infos_.front().name;
417 return false;
418 }
419 }
420 return true;
421 }
422
InitMaxShapeParam()423 bool GeGraphExecutor::InitMaxShapeParam() {
424 if (ref_data_infos_.empty()) {
425 MS_LOG(INFO) << "RefData count is empty";
426 return true;
427 }
428 if (!CheckRefDataInfo()) {
429 return false;
430 }
431 auto &ref_shape = ref_data_infos_.front().shape;
432 size_t b_index = kDim0;
433 size_t s_index = kDim2;
434 if (ref_shape.size() == kShape3dDims) { // BSH
435 s_index = kDim1;
436 } else if (ref_shape.size() == kShape4dDims) { // BNSD
437 s_index = kDim2;
438 } else {
439 MS_LOG(WARNING) << "RefData dim count is unexpected, shape " << ref_shape << ", name "
440 << ref_data_infos_.front().name;
441 return true;
442 }
443 std::string max_batch_size;
444 if (GetConfigOption("ascend_context", "max_batch_size", &max_batch_size)) {
445 MS_LOG(INFO) << "Get max batch size from config file, ascend_context, max_batch_size";
446 dyn_kv_cache_info_.max_batch_size = std::stoi(max_batch_size);
447 } else {
448 MS_LOG(INFO) << "Get max batch size from ref data shape : " << ref_shape;
449 dyn_kv_cache_info_.max_batch_size = ref_shape[b_index];
450 }
451
452 std::string max_seq_length;
453 if (GetConfigOption("ascend_context", "max_seq_length", &max_seq_length)) {
454 MS_LOG(INFO) << "Get max seq length from config file, ascend_context, max_seq_length";
455 dyn_kv_cache_info_.max_seq_len_size = std::stoi(max_seq_length);
456 } else {
457 MS_LOG(INFO) << "Get max seq length from ref data shape : " << ref_shape;
458 dyn_kv_cache_info_.max_seq_len_size = ref_shape[s_index];
459 }
460
461 MS_LOG(INFO) << "set dynamic max shape, max batch size : " << dyn_kv_cache_info_.max_batch_size
462 << ", max seq length: " << dyn_kv_cache_info_.max_seq_len_size;
463 return true;
464 }
465
InitRealShapeParam(const std::vector<tensor::Tensor> & inputs)466 bool GeGraphExecutor::InitRealShapeParam(const std::vector<tensor::Tensor> &inputs) {
467 if (!dyn_kv_cache_info_.dynamic_kv_cache) {
468 return true;
469 }
470 auto input_0_shape = inputs[0].shape_c();
471 if (input_0_shape.size() != kShape2dDims) {
472 MS_LOG(ERROR) << "Expected input 0 shape to be [bs, seq_length], but got " << input_0_shape;
473 return false;
474 }
475 dyn_kv_cache_info_.real_batch_size = input_0_shape.at(Index0);
476 MS_LOG(INFO) << "Real batch size : " << dyn_kv_cache_info_.real_batch_size;
477 dyn_kv_cache_info_.real_seq_len_size = input_0_shape.at(Index1);
478 MS_LOG(INFO) << "Real seq length size : " << dyn_kv_cache_info_.real_seq_len_size;
479 return true;
480 }
481
GetConfigOption(const std::string & section_name,const std::string & option_name,std::string * option_val)482 bool GeGraphExecutor::GetConfigOption(const std::string §ion_name, const std::string &option_name,
483 std::string *option_val) {
484 if (option_val == nullptr) {
485 MS_LOG(ERROR) << "Input argument option_val is nullptr";
486 return false;
487 }
488 auto config_it = config_infos_.find(section_name);
489 if (config_it == config_infos_.end()) {
490 return false;
491 }
492 auto &options = config_it->second;
493 auto option_it = options.find(option_name);
494 if (option_it == options.end()) {
495 return false;
496 }
497 *option_val = option_it->second;
498 return true;
499 }
500
GetRankID() const501 uint32_t GeGraphExecutor::GetRankID() const {
502 auto ascend_info = GeUtils::GetAscendDeviceInfo(context_);
503 if (ascend_info == nullptr) {
504 MS_LOG(ERROR) << "Can not find ascend device context.";
505 return 0;
506 }
507 return ascend_info->GetRankID();
508 }
509
GetDeviceID() const510 uint32_t GeGraphExecutor::GetDeviceID() const {
511 auto ascend_info = GeUtils::GetAscendDeviceInfo(context_);
512 if (ascend_info == nullptr) {
513 MS_LOG(ERROR) << "Can not find ascend device context.";
514 return 0;
515 }
516 return ascend_info->GetDeviceID();
517 }
518
Init()519 bool GeGraphExecutor::Init() {
520 ge_global_context_ = GeDeviceContext::InitGlobalContext(context_, config_infos_);
521 if (ge_global_context_ == nullptr) {
522 MS_LOG(ERROR) << "Failed to Init global context";
523 return false;
524 }
525 if (!InitRefModeConfig()) {
526 return false;
527 }
528 std::string model_cache_mode;
529 (void)GetConfigOption(lite::kAscendContextSection, lite::kModelCacheMode, &model_cache_mode);
530 if (!model_cache_mode.empty()) {
531 cache_mode_ = model_cache_mode;
532 MS_LOG(INFO) << "Set set model cache mode " << model_cache_mode;
533 }
534 std::string variable_weights_list;
535 (void)GetConfigOption(lite::kAscendContextSection, "variable_weights_list", &variable_weights_list);
536 if (!variable_weights_list.empty()) {
537 update_weight_ptr_ = std::make_shared<UpdateWeight>();
538 if (update_weight_ptr_ == nullptr) {
539 MS_LOG(ERROR) << "init update weight ptr failed.";
540 return false;
541 }
542 if (!update_weight_ptr_->ParseUpdateWeightConfig(variable_weights_list)) {
543 MS_LOG(ERROR) << "ParseUpdateWeightConfig failed.";
544 update_weight_ptr_ = nullptr;
545 return false;
546 }
547 enable_update_weight_ = true;
548 }
549 return true;
550 }
551
InitRefModeConfig()552 bool GeGraphExecutor::InitRefModeConfig() {
553 std::string ref_mode;
554 (void)GetConfigOption(lite::kAscendContextSection, lite::kParameterAsRefData, &ref_mode);
555 if (!ref_mode.empty()) {
556 ref_mode = lite::StringTolower(ref_mode);
557 if (ref_mode != kRefModeNone && ref_mode != kRefModeVariable && ref_mode != kRefModeAll) {
558 MS_LOG(ERROR) << "Only " << kRefModeNone << ", " << kRefModeVariable << " or " << kRefModeAll
559 << " is supported for " << lite::kParameterAsRefData << ", but got " << ref_mode;
560 return false;
561 }
562 if (ref_mode == kRefModeAll) {
563 ref_mode_flag_ = transform::RefModeFlag::kRefModeAll;
564 } else if (ref_mode == kRefModeVariable) {
565 ref_mode_flag_ = transform::RefModeFlag::kRefModeVariable;
566 } else {
567 ref_mode_flag_ = transform::RefModeFlag::kRefModeNone;
568 }
569 MS_LOG(INFO) << "Set parameter ref mode " << ref_mode;
570 } else {
571 ref_mode_flag_ = transform::RefModeFlag::kRefModeNone;
572 }
573 return true;
574 }
575
GetGeSessionOptions(std::map<std::string,std::string> * ge_options_ptr)576 void GeGraphExecutor::GetGeSessionOptions(std::map<std::string, std::string> *ge_options_ptr) {
577 MS_EXCEPTION_IF_NULL(ge_options_ptr);
578 auto &ge_options = *ge_options_ptr;
579 ge_options["ge.trainFlag"] = "0";
580 ge_options["ge.enablePrintOpPass"] = "0";
581 ge_options["ge.exec.device_id"] = std::to_string(GetDeviceID());
582 ge_options["ge.exec.staticMemoryPolicy"] = "2";
583 if (ref_mode_flag_ != transform::RefModeFlag::kRefModeNone) {
584 ge_options["ge.constLifecycle"] = "graph";
585 }
586 auto config_it = config_infos_.find(lite::kGeSessionOptionsSection);
587 if (config_it != config_infos_.end()) {
588 for (auto &item : config_it->second) {
589 ge_options[item.first] = item.second;
590 MS_LOG(INFO) << "Set ge session option " << item.first << " to " << item.second;
591 }
592 }
593 config_it = config_infos_.find(lite::kAscendContextSection);
594 if (config_it != config_infos_.end()) {
595 GetGeSessionOptionsFromAscendContext(config_it->second, ge_options_ptr);
596 }
597 }
598
SetModelCacheDir(std::map<std::string,std::string> * session_options_ptr)599 bool GeGraphExecutor::SetModelCacheDir(std::map<std::string, std::string> *session_options_ptr) {
600 auto &ge_options = *session_options_ptr;
601 auto build_cache_dir = "model_build_cache_" + std::to_string(GetRankID());
602 if (lite::CreateDir(build_cache_dir) != RET_OK) {
603 MS_LOG(ERROR) << "Failed to create build cache dir " << build_cache_dir;
604 return false;
605 }
606 ge_options[kGeGraphCompilerCacheDir] = build_cache_dir;
607 MS_LOG(INFO) << "Update session attr " << kGeGraphCompilerCacheDir << " to " << build_cache_dir;
608 return true;
609 }
610
// Decide where the offline (converter) build cache lives and register it as
// the GE graph compiler cache dir. Reuses the dir already recorded in a
// shared session context when one exists; otherwise derives
// "<converter output file without .mindir>_variables". Also records the cache
// dir as a path relative to the output dir when it is nested under it.
bool GeGraphExecutor::SetOfflineBuildModelCacheDir(std::map<std::string, std::string> *session_options_ptr) {
  std::string build_cache_dir;
  auto &ge_options = *session_options_ptr;
  bool build_cache_enabled = false;
  std::string output_file;
  (void)GetConfigOption(lite::kConverterParams, lite::kConverterOutputFile, &output_file);
  // Directory part of the converter output file (default: current dir).
  std::string output_dir = "./";
  if (output_file.find("/") != std::string::npos) {
    output_dir = output_file.substr(0, output_file.rfind("/") + 1);
  }
  session_id_ = GetSessionId();
  auto ge_session_context = GeSessionManager::GetGeSessionContext(session_id_);
  if (ge_session_context) {
    // An earlier executor in the same session already chose a cache dir: reuse it.
    const auto &last_ge_options = ge_session_context->session_options;
    if (auto it = last_ge_options.find(kGeGraphCompilerCacheDir); it != last_ge_options.end()) {
      build_cache_dir = it->second;
      build_cache_enabled = true;
    }
  }
  if (!build_cache_enabled) {
    // Strip a trailing ".mindir" so the cache dir sits next to the model file.
    // NOTE(review): find() locates the FIRST ".mindir"; a name containing it
    // twice would not be stripped — rfind() looks intended. Confirm.
    std::string mindir_postfix = ".mindir";
    auto ops = output_file.find(mindir_postfix);
    if (ops != std::string::npos && ops == output_file.size() - mindir_postfix.size()) {
      output_file = output_file.substr(0, output_file.size() - mindir_postfix.size());
    }
    if (output_file.empty()) {
      MS_LOG(ERROR) << "Converter output file cannot be empty";
      return false;
    }
    build_cache_dir = output_file + "_variables";
  }
  if (lite::CreateDir(build_cache_dir) != RET_OK) {
    MS_LOG(ERROR) << "Failed to create build cache dir " << build_cache_dir;
    return false;
  }
  ge_options[kGeGraphCompilerCacheDir] = build_cache_dir;
  MS_LOG(INFO) << "Update session attr " << kGeGraphCompilerCacheDir << " to " << build_cache_dir;
  if (build_cache_dir.find(output_dir) == 0) {
    // Cache dir is under the output dir: also remember it as a relative path.
    build_cache_relative_dir_ = "./" + build_cache_dir.substr(output_dir.size());
  }
  return true;
}
653
GetGeSessionOptionsFromAscendContext(const std::map<std::string,std::string> & config,std::map<std::string,std::string> * ge_options_ptr)654 void GeGraphExecutor::GetGeSessionOptionsFromAscendContext(const std::map<std::string, std::string> &config,
655 std::map<std::string, std::string> *ge_options_ptr) {
656 MS_EXCEPTION_IF_NULL(ge_options_ptr);
657 auto &ge_options = *ge_options_ptr;
658 auto option_id = config.find(lite::kDumpPathKey);
659 if (option_id != config.end()) {
660 auto dump_path = option_id->second;
661 auto real_path = lite::RealPath(dump_path.c_str());
662 std::ifstream ifs(real_path);
663 if (!ifs.good() || !ifs.is_open()) {
664 MS_LOG(EXCEPTION) << "The dump config file: " << real_path << " is not exit or open failed.";
665 }
666 nlohmann::json dump_cfg_json;
667 try {
668 dump_cfg_json = nlohmann::json::parse(ifs);
669 } catch (const nlohmann::json::parse_error &error) {
670 MS_LOG(EXCEPTION) << "parse json failed, please check the file: " << real_path;
671 }
672 if (dump_cfg_json[kDump] != nullptr && dump_cfg_json[kDump][kDumpMode] != nullptr) {
673 ge_options["ge.exec.enableDump"] = "1";
674 ge_options["ge.exec.dumpMode"] = dump_cfg_json[kDump][kDumpMode].get<std::string>();
675 }
676 }
677 option_id = config.find(lite::kProfilingPathKey);
678 if (option_id != config.end()) {
679 auto profiling_path = option_id->second;
680 auto real_path = lite::RealPath(profiling_path.c_str());
681 std::ifstream ifs(real_path);
682 if (!ifs.good() || !ifs.is_open()) {
683 MS_LOG(EXCEPTION) << "The profiling_path config file: " << real_path << " is not exit or open failed.";
684 }
685 nlohmann::json profiling_cfg_json;
686 try {
687 profiling_cfg_json = nlohmann::json::parse(ifs);
688 } catch (const nlohmann::json::parse_error &error) {
689 MS_LOG(EXCEPTION) << "parse json failed, please check the file: " << real_path;
690 }
691 if (profiling_cfg_json[kProfiling] != nullptr) {
692 ge_options["ge.exec.profilingMode"] = "1";
693 ge_options["ge.exec.profilingOptions"] = profiling_cfg_json[kProfiling].dump();
694 }
695 }
696 option_id = config.find(lite::kGeVariableMemoryMaxSize);
697 if (option_id != config.end()) {
698 ge_options["ge.variableMemoryMaxSize"] = option_id->second;
699 }
700 option_id = config.find(lite::kGeGraphMemoryMaxSize);
701 if (option_id != config.end()) {
702 ge_options["ge.graphMemoryMaxSize"] = option_id->second;
703 }
704 option_id = config.find(lite::kGraphCompilerCacheDirKey);
705 if (option_id != config.end()) {
706 ge_options[kGeGraphCompilerCacheDir] = option_id->second;
707 }
708 }
709
GetGeGraphOptions(const FuncGraphPtr & anf_graph,std::map<std::string,std::string> * ge_options_ptr)710 void GeGraphExecutor::GetGeGraphOptions(const FuncGraphPtr &anf_graph,
711 std::map<std::string, std::string> *ge_options_ptr) {
712 MS_EXCEPTION_IF_NULL(anf_graph);
713 MS_EXCEPTION_IF_NULL(ge_options_ptr);
714 auto &ge_options = *ge_options_ptr;
715 auto ascend_device_info = GeUtils::GetAscendDeviceInfo(context_);
716 if (ascend_device_info == nullptr) {
717 MS_LOG(EXCEPTION) << "Failed to get graph session options, can not find ascend device context.";
718 }
719 uint32_t rank_id = ascend_device_info->GetRankID();
720 graph_name_ = std::to_string(rank_id) + "_" + std::to_string(global_graph_idx_) + "_" + anf_graph->ToString();
721 for (auto &c : graph_name_) {
722 if (c == '.') {
723 c = '_';
724 }
725 }
726 ge_options[kGeGraphKey] = graph_name_;
727 auto config_it = config_infos_.find(lite::kGeGraphOptionsSection);
728 if (config_it != config_infos_.end()) {
729 for (auto &item : config_it->second) {
730 ge_options[item.first] = item.second;
731 MS_LOG(INFO) << "Set ge graph option " << item.first << " to " << item.second;
732 }
733 }
734
735 auto precision_mode = ascend_device_info->GetPrecisionMode();
736 if (!precision_mode.empty()) {
737 ge_options["ge.exec.precision_mode"] = TransforPrecisionToAcl(precision_mode);
738 }
739 config_it = config_infos_.find(lite::kAscendContextSection);
740 if (config_it == config_infos_.end()) {
741 return;
742 }
743 auto config = config_it->second;
744 auto option_id = config.find(lite::kModifyMixList);
745 if (option_id != config.end()) {
746 ge_options["ge.exec.modify_mixlist"] = option_id->second;
747 }
748 }
749
GetSessionId()750 int64_t GeGraphExecutor::GetSessionId() {
751 std::string inner_group_id;
752 (void)GetConfigOption(lite::kLiteInnerGroupSection, lite::kLiteInnerGroupId, &inner_group_id);
753 if (inner_group_id.empty()) {
754 return kUnkonwnSessionId;
755 }
756 int64_t session_id = kUnkonwnSessionId;
757 if (!lite::ConvertStrToInt(inner_group_id, &session_id)) {
758 MS_LOG(WARNING) << "Failed to parse session_id " << inner_group_id << " to int64_t";
759 return kUnkonwnSessionId;
760 }
761 return session_id;
762 }
763
CreateSession(const std::map<std::string,std::string> & extra_options)764 bool GeGraphExecutor::CreateSession(const std::map<std::string, std::string> &extra_options) {
765 if (ge_session_ != nullptr) {
766 MS_LOG(INFO) << "Ge session has already been created";
767 return true;
768 }
769 session_id_ = GetSessionId();
770 (void)setenv("GE_TRAIN", "0", 1);
771 std::map<std::string, std::string> session_options = extra_options;
772 GetGeSessionOptions(&session_options);
773 if (auto option_id = session_options.find(kGeGraphCompilerCacheDir); option_id != session_options.end()) {
774 build_cache_dir_ = option_id->second;
775 }
776 session_options_ = session_options;
777 ge_session_ = GeSessionManager::CreateGeSession(session_id_, session_options);
778 if (ge_session_ == nullptr) {
779 MS_LOG(ERROR) << "Failed to create ge session";
780 return false;
781 }
782 return true;
783 }
784
AddGraph(const transform::DfGraphPtr & graph,const std::map<std::string,std::string> & options,uint32_t * graph_id_ret)785 bool GeGraphExecutor::AddGraph(const transform::DfGraphPtr &graph, const std::map<std::string, std::string> &options,
786 uint32_t *graph_id_ret) {
787 if (ge_session_ == nullptr) {
788 MS_LOG(ERROR) << "Failed to add graph, ge session cannot be nullptr";
789 return false;
790 }
791 auto graph_id = GetNextGraphIdx();
792 for (auto &option : options) {
793 MS_LOG(INFO) << "GE Graph " << graph_id << " option " << option.first << " = " << option.second;
794 }
795 auto ge_status = ge_session_->AddGraph(static_cast<uint32_t>(graph_id), *(graph), options);
796 if (ge_status != ge::GRAPH_SUCCESS) {
797 MS_LOG(ERROR) << "Call GE AddGraph Failed: " << ge::GEGetErrorMsg();
798 return false;
799 }
800 *graph_id_ret = graph_id;
801 return true;
802 }
803
GetParams(const FuncGraphPtr & anf_graph,transform::TensorOrderMap * param_tensors)804 void GeGraphExecutor::GetParams(const FuncGraphPtr &anf_graph, transform::TensorOrderMap *param_tensors) {
805 MS_EXCEPTION_IF_NULL(anf_graph);
806
807 transform::TensorOrderMap res;
808 for (auto &anf_node : anf_graph->parameters()) {
809 MS_EXCEPTION_IF_NULL(anf_node);
810 auto para = anf_node->cast<ParameterPtr>();
811 MS_EXCEPTION_IF_NULL(para);
812 if (para->has_default()) {
813 auto value = para->default_param();
814 MS_EXCEPTION_IF_NULL(value);
815 auto tensor = value->cast<std::shared_ptr<tensor::Tensor>>();
816 MS_EXCEPTION_IF_NULL(tensor);
817 auto para_name = para->name();
818 res.emplace(para_name, tensor);
819 }
820 }
821 if (session_id_ != kUnkonwnSessionId) {
822 std::vector<std::string> graph_params;
823 std::transform(res.begin(), res.end(), std::back_inserter(graph_params),
824 [](const auto &item) { return item.first; });
825 auto new_params_set = GeSessionManager::UpdateSessionVariables(session_id_, graph_params);
826 for (auto &item : res) {
827 // parameters not in new_params_set has been init by other graph
828 if (new_params_set.find(item.first) == new_params_set.end()) {
829 item.second->set_init_flag(true);
830 }
831 }
832 }
833 *param_tensors = res;
834 }
835
UpdateGraphInputs(const FuncGraphPtr & graph)836 bool GeGraphExecutor::UpdateGraphInputs(const FuncGraphPtr &graph) {
837 std::string input_shape_str;
838 std::vector<GeDynamicShapeInfo> input_shapes;
839 if (!GeDynamicUtils::GetGraphInputShapes(context_, config_infos_, &input_shapes, &input_shape_str)) {
840 MS_LOG(ERROR) << "Failed to get input shape from AscendDeviceInfo or config file";
841 return false;
842 }
843 if (input_shapes.empty()) {
844 MS_LOG(INFO) << "Not found input shape in AscendDeviceInfo or config file";
845 return true;
846 }
847 auto inputs = graph->get_inputs();
848 if (inputs.size() != input_shapes.size()) {
849 MS_LOG(ERROR) << "FuncGraph input size " << inputs.size() << " != input size " << input_shapes.size()
850 << " in AscendDeviceInfo or config file " << input_shapes.size();
851 return false;
852 }
853 for (size_t i = 0; i < input_shapes.size(); i++) {
854 auto node = inputs[i];
855 MS_CHECK_TRUE_RET(node != nullptr, false);
856 auto input_shape = input_shapes[i];
857 auto para = node->cast<ParameterPtr>();
858 if (para == nullptr) {
859 MS_LOG(ERROR) << "Cast input to Parameter failed";
860 return false;
861 }
862 MS_LOG(INFO) << "Func graph input_" << i << " " << para->name()
863 << ", shape: " << FuncGraphUtils::GetTensorShape({node, 0});
864
865 auto it = std::find_if(input_shapes.begin(), input_shapes.end(),
866 [¶](const auto &item) { return item.name == para->name(); });
867 if (it == input_shapes.end()) {
868 MS_LOG(ERROR) << "Failed to find input " << para->name() << " in input_shape " << input_shape_str;
869 return false;
870 }
871 auto abstract = para->abstract();
872 if (abstract == nullptr) {
873 MS_LOG(ERROR) << "Get input abstract failed";
874 return false;
875 }
876 ShapeVector shape;
877 std::transform(it->shape.begin(), it->shape.end(), std::back_inserter(shape), [](auto &dim) { return dim.dim; });
878 MS_LOG(INFO) << "Update shape of input_" << i << " " << para->name() << " to " << shape;
879 abstract->set_shape(std::make_shared<abstract::Shape>(shape));
880 }
881 return true;
882 }
883
InitRefDataList(const std::vector<std::pair<std::string,tensor::TensorPtr>> & ref_data_tensors)884 bool GeGraphExecutor::InitRefDataList(const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors) {
885 for (auto &item : ref_data_tensors) {
886 auto para_name = item.first;
887 auto &tensor = item.second;
888 RefDataInfo ref_data_info;
889 ref_data_info.name = para_name;
890 ref_data_info.shape = tensor->shape_c();
891 ref_data_info.dtype = tensor->data_type();
892 ref_data_info.host_data = item.second;
893 MS_LOG(INFO) << "Init ref data info[" << ref_data_infos_.size() << "] :" << ref_data_info.name
894 << ", dtype:" << ref_data_info.dtype << ", shape:" << ref_data_info.shape;
895 ref_data_infos_.push_back(ref_data_info);
896 }
897 return true;
898 }
899
// Acquire or create the GE memory manager and device-context manager for this
// executor. Both managers can be shared across executors of the same session:
// the session context stores weak_ptrs, so a still-alive instance is reused and
// a freshly-created one is published back for later executors.
bool GeGraphExecutor::InitMemoryContextManager() {
  auto session_context = GeSessionManager::GetGeSessionContext(session_id_);
  if (session_context != nullptr) {
    // Reuse managers created by another executor in the same session, if alive.
    memory_manager_ = session_context->memory_manager.lock();
    context_manager_ = session_context->context_manager.lock();
  }
  if (memory_manager_ == nullptr) {
    memory_manager_ = std::make_shared<GeMemoryManager>();
    if (memory_manager_ == nullptr) {
      MS_LOG(ERROR) << "Failed to create memory manager";
      return false;
    }
    if (session_context != nullptr) {
      // Publish so later executors of this session can share the manager.
      session_context->memory_manager = memory_manager_;
    }
  }
  if (context_manager_ == nullptr) {
    context_manager_ = std::make_shared<GeContextManager>();
    if (context_manager_ == nullptr) {
      MS_LOG(ERROR) << "Failed to create context manager";
      return false;
    }
    // A newly created context manager must first bind to the target device.
    if (!context_manager_->InitContext(GetDeviceID())) {
      MS_LOG(ERROR) << "Failed to init device";
      return false;
    }
    if (session_context != nullptr) {
      session_context->context_manager = context_manager_;
    }
  }
  // Make the (possibly shared) device context current for this thread.
  if (!context_manager_->SetContext()) {
    MS_LOG(ERROR) << "Failed to set ge context";
    return false;
  }
  return true;
}
936
// Allocate and bind device memory for all RefData tensors. A RefData already
// registered in the session context (initialized by a previous graph of the
// same session) reuses that graph's device buffer; each new RefData is assigned
// an aligned slice of a single freshly-allocated block and its host data is
// copied to the device, after which it is published for later graphs to share.
bool GeGraphExecutor::InitRefDataDeviceTensor() {
  MS_LOG(INFO) << "InitRefDataDeviceTensor start.";
  if (ref_data_infos_.empty()) {
    MS_LOG(INFO) << "There is not ref data, no need to init ref data device data";
    return true;
  }
  // RefData buffers already created by earlier graphs in this session.
  std::map<std::string, RefDataInfo> session_ref_data_map;
  auto session_context = GeSessionManager::GetGeSessionContext(session_id_);
  if (session_context != nullptr) {
    session_ref_data_map = session_context->ref_data_map_;
  }

  size_t ref_data_total_size = 0;
  std::map<std::string, tensor::TensorPtr> new_param_tensor_map;
  for (size_t i = 0; i < ref_data_infos_.size(); i++) {
    auto &item = ref_data_infos_[i];
    auto tensor = item.host_data;
    item.size = tensor->Size();
    item.host_data = nullptr;  // release host memory
    ShapeVector ref_data_shape = tensor->shape_c();
    SetRefShape(&ref_data_shape, true, item.name);
    auto desc = transform::TransformUtil::GetGeTensorDesc(ref_data_shape, tensor->data_type(), kOpFormat_NCHW);
    if (desc == nullptr) {
      MS_LOG(ERROR) << "Failed to get Tensor Desc";
      return false;
    }
    desc->SetPlacement(::ge::kPlacementDevice);
    auto ret = item.ge_tensor.SetTensorDesc(*desc);
    if (ret != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "Failed to call ge::Tensor::SetTensorDesc, ret " << ret;
      return false;
    }
    if (auto ref_it = session_ref_data_map.find(item.name); ref_it != session_ref_data_map.end()) {
      // Shared RefData: point this graph's GE tensor at the existing device buffer.
      auto &org_item = ref_it->second;
      MS_LOG(INFO) << "Find RefData " << item.name << ", shape " << org_item.shape << ", size " << org_item.size;
      if (org_item.size != item.size) {
        MS_LOG(ERROR) << "RefData " << item.name << " data size != the size in pre graph, current shape " << item.shape
                      << ", size " << item.size << ", pre shape " << org_item.shape << ", pre size " << org_item.size;
        return false;
      }
      auto dst_addr = ref_it->second.ge_tensor.GetData();
      // No-op deleter: the original owner keeps the device buffer alive.
      ret = item.ge_tensor.SetData(dst_addr, item.size, [](uint8_t *) -> void {});
      if (ret != ge::GRAPH_SUCCESS) {
        MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << item.size;
        return false;
      }
    } else {
      // New RefData: reserve an aligned slice in the block allocated below.
      item.offset = ref_data_total_size;
      ref_data_total_size += ALIGN_UP_REF_DATA(tensor->Size());
      new_param_tensor_map[item.name] = tensor;
    }
  }
  if (ref_data_total_size != 0) {
    auto device_memory = memory_manager_->MallocDeviceMemory("RefData input", ref_data_total_size);
    if (device_memory == nullptr) {
      return false;
    }
    for (auto &item : ref_data_infos_) {
      auto it = new_param_tensor_map.find(item.name);
      if (it == new_param_tensor_map.end()) {
        continue;  // already bound to a shared buffer above
      }
      auto &tensor_val = it->second;
      auto dst_addr = device_memory + item.offset;
      if (!memory_manager_->MemcpyHost2Device(dst_addr, item.size, tensor_val->data_c(), tensor_val->Size())) {
        MS_LOG(ERROR) << "Failed to memory copy host data to device";
        return false;
      }
      auto ret = item.ge_tensor.SetData(dst_addr, item.size, [](uint8_t *) -> void {});
      if (ret != ge::GRAPH_SUCCESS) {
        MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << item.size;
        return false;
      }
      // Publish the new buffer so later graphs of this session can share it.
      if (session_context != nullptr) {
        session_context->ref_data_map_[item.name] = item;
      }
    }
  }
  return true;
}
1017
InitInOutDeviceBuffer(const std::string & name,const ShapeVector & shape,TypeId dtype,InOutBufferInfo * buffer_info)1018 bool GeGraphExecutor::InitInOutDeviceBuffer(const std::string &name, const ShapeVector &shape, TypeId dtype,
1019 InOutBufferInfo *buffer_info) {
1020 auto &info = *buffer_info;
1021 auto desc = transform::TransformUtil::GetGeTensorDesc(shape, dtype, kOpFormat_NCHW);
1022 if (desc == nullptr) {
1023 MS_LOG(ERROR) << "Failed to get Tensor Desc";
1024 return false;
1025 }
1026 auto tensor_size = SizeOf(shape) * GetDataTypeSize(dtype);
1027 if (tensor_size <= 0) {
1028 MS_LOG(INFO) << "Failed to calculate " << name << " tensor size, shape " << ShapeVectorToStr(shape)
1029 << ", date type " << dtype;
1030 return false;
1031 }
1032 desc->SetPlacement(::ge::kPlacementDevice);
1033 auto ret = info.ge_tensor.SetTensorDesc(*desc);
1034 if (ret != ACL_ERROR_NONE) {
1035 MS_LOG(ERROR) << "Failed to call ge::Tensor::SetTensorDesc, ret " << ret;
1036 return false;
1037 }
1038 info.device_addr = memory_manager_->MallocDeviceMemory(name, tensor_size);
1039 if (info.device_addr == nullptr) {
1040 MS_LOG(ERROR) << "Failed to malloc device memory for " << name << ", memory size " << tensor_size
1041 << ", tensor shape " << shape;
1042 return false;
1043 }
1044 ret = info.ge_tensor.SetData(reinterpret_cast<uint8_t *>(info.device_addr), tensor_size, [](uint8_t *) -> void {});
1045 if (ret != ge::GRAPH_SUCCESS) {
1046 MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(uint8_t*, size, DeleteFunc), data size " << tensor_size;
1047 return false;
1048 }
1049 info.max_size = tensor_size;
1050 info.shape = shape;
1051 info.dtype = dtype;
1052 return true;
1053 }
1054
// Rebuild the "ge.inputShape" option string from (a) the input shapes configured
// in AscendDeviceInfo/config file or, when none are configured and dynamic
// KVCache is enabled, the graph's own input shapes, plus (b) the (possibly
// dynamic) shapes of all RefData tensors. The merged string is written back into
// the config infos and into *ge_options_ptr.
bool GeGraphExecutor::UpdateInputShapeOption(
  const FuncGraphPtr &func_graph, const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors,
  std::map<std::string, std::string> *ge_options_ptr) {
  if (ge_options_ptr == nullptr) {
    MS_LOG(ERROR) << "Input argument ge_options_ptr cannot be nullptr";
    return false;
  }
  std::string input_shape_str;
  std::vector<GeDynamicShapeInfo> input_shapes;
  if (!GeDynamicUtils::GetGraphInputShapes(context_, config_infos_, &input_shapes, &input_shape_str)) {
    MS_LOG(ERROR) << "Failed to get input shape from AscendDeviceInfo or config file";
    return false;
  }
  // Maps input/RefData name -> comma-joined shape string.
  std::map<std::string, std::string> shape_map;
  if (input_shapes.empty()) {
    MS_LOG(INFO) << "Not found input shape in AscendDeviceInfo or config file";
    if (!dyn_kv_cache_info_.dynamic_kv_cache) {
      // Nothing configured and no dynamic KVCache: leave ge.inputShape untouched.
      return true;
    }
    // Dynamic KVCache: derive the shape map from the graph's own inputs.
    auto inputs = func_graph->get_inputs();
    bool dyn_input = false;
    for (auto &item : inputs) {
      auto shape = FuncGraphUtils::GetTensorShape({item, 0});
      if (std::any_of(shape.begin(), shape.end(), [](auto dim) { return dim < 0; })) {
        dyn_input = true;
      }
      shape_map[item->fullname_with_scope()] = lite::VectorToStrJoin(shape, ",");
    }
    if (!dyn_input) {
      MS_LOG(INFO) << "Current model has no dynamic inputs and there is no ge.inputShape set in config, skip update "
                      "ge.inputShape option for dynamic KVCache";
      return true;
    }
  } else {
    for (auto &item : input_shapes) {
      shape_map[item.name] = item.shape_str;
    }
  }
  // RefData entries are appended; a same-named entry overrides the input shape.
  for (auto &item : ref_data_tensors) {
    ShapeVector ref_dyn_shape = item.second->shape_c();
    SetRefShape(&ref_dyn_shape, true, item.first);
    shape_map[item.first] = lite::VectorToStrJoin(ref_dyn_shape, ",");
  }
  std::string new_input_shape_str = lite::MapToStrJoin(shape_map, ":", ";");
  GeDynamicUtils::UpdateGraphInputShapes(context_, &config_infos_, new_input_shape_str);
  (*ge_options_ptr)["ge.inputShape"] = new_input_shape_str;
  MS_LOG(INFO) << "Update ge.inputShape to " << new_input_shape_str;
  return true;
}
1104
// Set up everything needed to run the graph in RefData mode: refresh the
// ge.inputShape option, record the RefData meta info and compute max shapes.
// Each step must succeed before the next runs; any failure aborts compilation.
bool GeGraphExecutor::InitRefDataContext(const FuncGraphPtr &func_graph,
                                         const std::vector<std::pair<std::string, tensor::TensorPtr>> &ref_data_tensors,
                                         std::map<std::string, std::string> *ge_options_ptr) {
  if (!UpdateInputShapeOption(func_graph, ref_data_tensors, ge_options_ptr)) {
    MS_LOG(ERROR) << "Failed to update input shape option";
    return false;
  }
  if (!InitRefDataList(ref_data_tensors)) {
    MS_LOG(ERROR) << "Failed to init ref data list";
    return false;
  }
  if (!InitMaxShapeParam()) {
    MS_LOG(ERROR) << "Failed to init max shape size";
    return false;
  }
  return true;
}
1122
CreateFakeGraph(const std::map<std::string,std::string> & ge_options)1123 transform::DfGraphPtr GeGraphExecutor::CreateFakeGraph(const std::map<std::string, std::string> &ge_options) {
1124 if (enable_update_weight_) {
1125 MS_LOG(INFO) << "Enable update weight, skip create small ge graph";
1126 return nullptr;
1127 }
1128 if (build_cache_dir_.empty()) {
1129 MS_LOG(INFO) << "Option model_cache_mode " << cache_mode_ << " is not mem_opt and not load offline model or "
1130 << kGeGraphCompilerCacheDir << " is empty, skip create small ge graph";
1131 return nullptr;
1132 }
1133 auto graph_it = ge_options.find(kGeGraphKey);
1134 if (graph_it == ge_options.end()) {
1135 MS_LOG(INFO) << "Cannot find option " << kGeGraphKey << ", skip create small ge graph";
1136 return nullptr;
1137 }
1138 auto graph_key = graph_it->second;
1139 auto idx_file_name = build_cache_dir_ + "/" + graph_key + ".idx";
1140 if (!UpdateOmCacheIdxFile(idx_file_name)) {
1141 return nullptr;
1142 }
1143 auto df_graph = GenExampleGraph(graph_key);
1144 if (df_graph == nullptr) {
1145 MS_LOG(WARNING) << "Failed to create small ge graph for graph " << graph_key << ", skip create small ge graph";
1146 return nullptr;
1147 }
1148 MS_LOG(INFO) << "Create small ge graph for graph " << graph_key;
1149 return df_graph;
1150 }
1151
UpdateWeights(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> & weights)1152 bool GeGraphExecutor::UpdateWeights(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> &weights) {
1153 auto time1 = lite::GetTimeUs();
1154 if (init_graph_id_list_.empty()) {
1155 MS_LOG(ERROR) << "init graph id list is empty.";
1156 return false;
1157 }
1158 uint32_t init_graph_id = init_graph_id_list_[0];
1159 MS_LOG(INFO) << "init_graph_id: " << init_graph_id;
1160 if (update_weight_ptr_ == nullptr) {
1161 MS_LOG(ERROR) << "please init update weight class by build model.";
1162 return false;
1163 }
1164 std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> new_weight_tensors;
1165 auto ret = update_weight_ptr_->UpdateConstantTensorData(weights, &new_weight_tensors);
1166 if (!ret) {
1167 MS_LOG(ERROR) << "update weight failed.";
1168 return false;
1169 }
1170 MS_LOG(DEBUG) << "ExecInitGraph start.";
1171 auto time2 = lite::GetTimeUs();
1172 MS_LOG(INFO) << "update weight prepare time: " << (time2 - time1) / kNumMicrosecondToMillisecond << " ms";
1173
1174 // cppcheck-suppress cppcheckError
1175 for (size_t i = 0; i < new_weight_tensors.size(); i++) {
1176 std::vector<::ge::Tensor> ge_inputs;
1177 // cppcheck-suppress cppcheckError
1178 for (size_t j = 0; j < new_weight_tensors[i].size(); j++) {
1179 auto &input = new_weight_tensors[i][j];
1180 auto ge_tensor = transform::TransformUtil::ConvertTensor(input, kOpFormat_NCHW, false);
1181 if (ge_tensor == nullptr) {
1182 MS_LOG(ERROR) << "Failed to converter input " << i << " ME Tensor to GE Tensor";
1183 return false;
1184 }
1185 ge_inputs.emplace_back(*ge_tensor);
1186 }
1187 std::vector<::ge::Tensor> ge_outputs;
1188 auto ge_status = ge_session_->RunGraph(init_graph_id, ge_inputs, ge_outputs);
1189 if (ge_status != ge::GRAPH_SUCCESS) {
1190 MS_LOG(ERROR) << "Exec init graph failed, graph id " << init_graph_id;
1191 return false;
1192 }
1193 }
1194 auto time3 = lite::GetTimeUs();
1195 MS_LOG(INFO) << "update weight run init graph time: " << (time3 - time2) / kNumMicrosecondToMillisecond << " ms";
1196 return true;
1197 }
1198
CreateGeGraphOnline(const FuncGraphPtr & anf_graph,std::map<std::string,std::string> * ge_options_ptr)1199 transform::DfGraphPtr GeGraphExecutor::CreateGeGraphOnline(const FuncGraphPtr &anf_graph,
1200 std::map<std::string, std::string> *ge_options_ptr) {
1201 std::vector<std::string> extra_variables_names = {};
1202 if (enable_update_weight_ && update_weight_ptr_ != nullptr) {
1203 auto ret = update_weight_ptr_->CreateAddOpNodeForGraph(anf_graph);
1204 if (!ret) {
1205 MS_LOG(ERROR) << "CreateAddOpNodeForGraph failed.";
1206 return nullptr;
1207 }
1208 extra_variables_names = update_weight_ptr_->GetVariableParamsName(anf_graph);
1209 if (extra_variables_names.empty()) {
1210 MS_LOG(WARNING) << "GetVariableParamsName failed.";
1211 return nullptr;
1212 }
1213 }
1214 transform::TensorOrderMap params_vals;
1215 GetParams(anf_graph, ¶ms_vals);
1216 transform::SetDynRefDataFunc dyn_ref_data_func = nullptr;
1217 if (dyn_kv_cache_info_.dynamic_kv_cache) {
1218 dyn_ref_data_func = [this](const AnfNodePtr &node, const ShapeVector &org_shape) -> ShapeVector {
1219 return SetKVCacheShape(dyn_kv_cache_info_.batch_size_dyn, dyn_kv_cache_info_.seq_length_dyn,
1220 dyn_kv_cache_info_.kv_cache_layout, org_shape);
1221 };
1222 }
1223
1224 MS_LOG(INFO) << "extra_variables_names size: " << extra_variables_names.size();
1225 auto converter = std::make_shared<transform::DfGraphConvertor>(anf_graph, "", ref_mode_flag_, extra_variables_names,
1226 dyn_ref_data_func);
1227 transform::BuildGraph(graph_name_, converter, params_vals);
1228 auto err_code = transform::ErrCode(converter);
1229 if (err_code != 0) {
1230 transform::ClearGraph();
1231 MS_LOG(ERROR) << "Convert df graph failed, err:" << err_code;
1232 return nullptr;
1233 }
1234 auto init_graph = transform::GetInitGraph(converter);
1235 if (init_graph != nullptr) {
1236 uint32_t init_graph_id = 0;
1237 if (!AddGraph(init_graph, {}, &init_graph_id)) {
1238 MS_LOG(ERROR) << "Failed to add init graph, graph name " << anf_graph->ToString();
1239 return nullptr;
1240 }
1241 if (enable_update_weight_ && update_weight_ptr_ != nullptr) {
1242 init_graph_id_list_.push_back(init_graph_id);
1243 }
1244 auto init_data_names = converter->GetInitDataNames();
1245 if (enable_update_weight_ && update_weight_ptr_ != nullptr) {
1246 if (!update_weight_ptr_->SetInitDataNames(init_data_names)) {
1247 MS_LOG(ERROR) << "set init data name failed.";
1248 return nullptr;
1249 }
1250 }
1251 // copy init weight to device
1252 if (!RunGeInitGraph(init_graph_id, init_data_names, params_vals)) {
1253 MS_LOG(ERROR) << "Failed to run init graph for " << anf_graph->ToString();
1254 return nullptr;
1255 }
1256 if (!enable_update_weight_) {
1257 ge_session_->RemoveGraph(init_graph_id);
1258 }
1259 } else {
1260 MS_LOG(INFO) << "There is no init graph for graph " << anf_graph->ToString();
1261 }
1262 if (ref_mode_flag_ != transform::RefModeFlag::kRefModeNone) {
1263 auto ref_data_names = converter->GetRefDataNames();
1264 std::vector<std::pair<std::string, tensor::TensorPtr>> ref_datas;
1265 std::transform(ref_data_names.begin(), ref_data_names.end(), std::back_inserter(ref_datas),
1266 [¶ms_vals](auto &item) { return std::make_pair(item, params_vals.at(item)); });
1267 if (!InitRefDataContext(anf_graph, ref_datas, ge_options_ptr)) {
1268 MS_LOG(ERROR) << "Failed to init refdata context";
1269 return nullptr;
1270 }
1271 }
1272 auto df_graph = transform::GetComputeGraph(converter);
1273 return df_graph;
1274 }
1275
SetOptionsIntoOfflineModel(const std::map<std::string,std::string> & graph_options,std::map<std::string,ValuePtr> * attr_map_ptr)1276 void GeGraphExecutor::SetOptionsIntoOfflineModel(const std::map<std::string, std::string> &graph_options,
1277 std::map<std::string, ValuePtr> *attr_map_ptr) {
1278 auto &attr_map = *attr_map_ptr;
1279
1280 if (!build_cache_relative_dir_.empty()) {
1281 attr_map[lite::kNameAttrWeightDir] = MakeValue(build_cache_relative_dir_);
1282 MS_LOG(INFO) << "Set graph attr " << lite::kNameAttrWeightDir << " to " << build_cache_relative_dir_;
1283 }
1284 // ge session options
1285 auto find_set_option = [](const std::map<std::string, std::string> &from_options,
1286 std::vector<std::string> *to_options, const std::string &option) {
1287 auto config_it = from_options.find(option);
1288 if (config_it != from_options.end()) {
1289 to_options->push_back(option);
1290 to_options->push_back(config_it->second);
1291 }
1292 };
1293 std::vector<std::string> session_save_options;
1294 find_set_option(session_options_, &session_save_options, "ge.externalWeight");
1295 attr_map[lite::kGeSessionOptionsSection] = MakeValue(session_save_options);
1296
1297 std::vector<std::string> graph_save_options;
1298 find_set_option(graph_options, &graph_save_options, "ge.inputShape");
1299 find_set_option(graph_options, &graph_save_options, "ge.dynamicDims");
1300 find_set_option(graph_options, &graph_save_options, "ge.dynamicNodeType");
1301 attr_map[lite::kGeGraphOptionsSection] = MakeValue(graph_save_options);
1302 }
1303
LoadOnlineGraph(const FuncGraphPtr & anf_graph,uint32_t * graph_id)1304 bool GeGraphExecutor::LoadOnlineGraph(const FuncGraphPtr &anf_graph, uint32_t *graph_id) {
1305 std::map<std::string, std::string> extra_session_options;
1306 if (!cache_mode_.empty()) {
1307 if (!SetModelCacheDir(&extra_session_options)) {
1308 return false;
1309 }
1310 }
1311 if (!CreateSession(extra_session_options)) {
1312 MS_LOG(ERROR) << "Failed to create ge session";
1313 return false;
1314 }
1315 std::map<std::string, std::string> ge_options;
1316 GetGeGraphOptions(anf_graph, &ge_options);
1317 auto df_graph = CompileGraphCommon(anf_graph, &ge_options);
1318 if (df_graph == nullptr) {
1319 MS_LOG(ERROR) << "Input param graph is nullptr.";
1320 return false;
1321 }
1322 if (cache_mode_ == "mem_opt") {
1323 auto fake_df_graph = CreateFakeGraph(ge_options);
1324 if (fake_df_graph != nullptr) {
1325 df_graph = fake_df_graph;
1326 }
1327 }
1328 if (!AddGraph(df_graph, ge_options, graph_id)) {
1329 MS_LOG(ERROR) << "Failed to add compute graph, graph name " << anf_graph->ToString();
1330 return false;
1331 }
1332 return true;
1333 }
1334
// Shared compilation pipeline used by both online loading and AOE tuning:
// optional graph-kernel optimization, Load-node removal, input shape update,
// attr-to-input conversion, and finally anf -> GE graph conversion along either
// the normal or the data-flow path.
transform::DfGraphPtr GeGraphExecutor::CompileGraphCommon(const FuncGraphPtr &anf_graph,
                                                          std::map<std::string, std::string> *ge_options_ptr) {
  if (anf_graph == nullptr) {
    MS_LOG(ERROR) << "Input param graph is nullptr.";
    return nullptr;
  }
  if (ge_options_ptr == nullptr) {
    MS_LOG(ERROR) << "Input param ge_options_ptr is nullptr.";
    return nullptr;
  }

#ifdef MSLITE_ENABLE_GRAPH_KERNEL
  auto param = ParseGraphKernelConfigs(config_infos_);
  if (param != nullptr) {
    // Graph-kernel optimization needs RANK_ID; derive it from device info when unset.
    auto rank_id = common::GetEnv("RANK_ID");
    if (rank_id.empty()) {
      auto ascend_device_info = GeUtils::GetAscendDeviceInfo(context_);
      if (ascend_device_info != nullptr) {
        auto rank_id_value = ascend_device_info->GetRankID();
        common::SetEnv("RANK_ID", std::to_string(rank_id_value).c_str());
      }
    }
    if (GraphKernelOptimize(anf_graph, param) != lite::RET_OK) {
      MS_LOG(ERROR) << "Run graphkernel optimization failed.";
      return nullptr;
    }
  }
#endif

  auto remove_load_pass = std::make_shared<opt::RemoveLoadPass>();
  remove_load_pass->Run(anf_graph);

  // Apply configured input shapes before conversion (may make inputs static).
  if (!UpdateGraphInputs(anf_graph)) {
    MS_LOG(ERROR) << "Failed to update graph inputs";
    return nullptr;
  }

  opt::UpdateManager(anf_graph);

  // Convert mindir attributes to inputs because of dynamic_shape operator.
  // For the transformed operators, the GE adapter only supports inputs but not attributes.
  auto args_to_attr_pass = std::make_shared<opt::AttrToArgsPass>();
  if (args_to_attr_pass == nullptr) {
    MS_LOG(ERROR) << "create AttrToArgsPass failed";
    return nullptr;
  }
  if (!args_to_attr_pass->Run(anf_graph)) {
    MS_LOG(ERROR) << "convert args to attr pass failed";
    return nullptr;
  }

  transform::DfGraphPtr df_graph = nullptr;
  // The graph attr selects the normal or the dedicated data-flow conversion path.
  auto func_type = anf_graph->get_attr(kAttrFuncType);
  is_data_flow_graph_ = func_type != nullptr && GetValue<std::string>(func_type) == kDataFlowGraphType;
  if (!is_data_flow_graph_) {
    df_graph = CreateGeGraphOnline(anf_graph, ge_options_ptr);
  } else {
    df_graph = GetDataFlowGraph(anf_graph, *ge_options_ptr);
  }
  return df_graph;
}
1396
CompileGraph(const FuncGraphPtr & anf_graph,const std::map<string,string> &,uint32_t * graph_id)1397 bool GeGraphExecutor::CompileGraph(const FuncGraphPtr &anf_graph, const std::map<string, string> &,
1398 uint32_t *graph_id) {
1399 MS_CHECK_TRUE_RET(graph_id != nullptr, false);
1400 uint32_t compute_graph_id = 0;
1401 if (CustomAscendUtils::IsCustomFuncGraph(anf_graph)) {
1402 MS_LOG(ERROR) << "Offline converted MindIR is not supported currently";
1403 return false;
1404 } else {
1405 auto ret = LoadOnlineGraph(anf_graph, &compute_graph_id);
1406 if (!ret) {
1407 MS_LOG(ERROR) << "Failed to load online model";
1408 return false;
1409 }
1410 }
1411 compute_graph_id_list_.push_back(compute_graph_id);
1412 *graph_id = compute_graph_id;
1413 if (ref_mode_flag_ != transform::RefModeFlag::kRefModeNone) {
1414 if (!BuildGraphRefMode(anf_graph, compute_graph_id)) {
1415 MS_LOG(ERROR) << "Failed to build ge graph with refdata";
1416 return false;
1417 }
1418 }
1419 std::vector<tensor::TensorPtr> orig_output;
1420 std::vector<std::string> output_names;
1421 FuncGraphUtils::GetFuncGraphOutputsInfo(anf_graph, &orig_output, &output_names);
1422 original_graph_outputs_[*graph_id] = orig_output;
1423 return true;
1424 }
1425
// Build one set of concrete GE input tensors for the graph: shapes come from
// the configured "one real shape" when present, otherwise from the graph itself
// (which must then be fully static). Used e.g. to feed AOE tuning.
bool GeGraphExecutor::GetOneRealInputs(const FuncGraphPtr &anf_graph, std::vector<ge::Tensor> *ge_tensors_ptr) {
  std::vector<std::pair<std::string, ShapeVector>> input_shapes_configs;
  std::string input_shape_str;
  if (!GeDynamicUtils::GetGraphOneRealShapes(context_, config_infos_, &input_shapes_configs, &input_shape_str)) {
    MS_LOG(ERROR) << "Failed to get one real input shape";
    return false;
  }
  std::vector<tensor::TensorPtr> inputs;
  std::vector<std::string> input_names;
  FuncGraphUtils::GetFuncGraphInputsInfo(anf_graph, &inputs, &input_names);
  if (!input_shapes_configs.empty() && input_shapes_configs.size() != inputs.size()) {
    MS_LOG(ERROR) << "Input count " << input_shapes_configs.size()
                  << " get from input_shape of AscendDeviceInfo or config file != input count " << inputs.size()
                  << " got from graph";
    return false;
  }
  std::vector<::ge::Tensor> ge_inputs;
  // cppcheck-suppress cppcheckError
  for (size_t i = 0; i < inputs.size(); i++) {
    auto &input = inputs[i];
    auto input_name = input_names[i];
    if (!input_shapes_configs.empty()) {
      // Configured shapes are matched to graph inputs by name, not by position.
      auto it = std::find_if(input_shapes_configs.begin(), input_shapes_configs.end(),
                             [&input_name](const auto &item) { return input_name == item.first; });
      if (it == input_shapes_configs.end()) {
        MS_LOG(ERROR) << "Cannot find input " << input_name << " in input_shape " << input_shape_str;
        return false;
      }
      // Replace the (possibly dynamic) graph input with a tensor of the configured shape.
      input = std::make_shared<tensor::Tensor>(input->data_type(), it->second);
    } else if (GeDynamicUtils::IsDynamicInputShapes({input->shape_c()})) {
      // No configured shapes: every graph input must already be static.
      MS_LOG(ERROR) << "Input " << i << " is dynamic shape " << input->shape_c()
                    << ", but there is no input shape specified in AscendDeviceInfo or config file";
      return false;
    }
    MS_LOG(INFO) << "Input " << i << " shape " << input->shape_c() << ", datatype " << input->data_type();
    auto ge_tensor = transform::TransformUtil::ConvertTensor(input, kOpFormat_NCHW);
    if (ge_tensor == nullptr) {
      MS_LOG(ERROR) << "Failed to converter input " << i << " ME Tensor to GE Tensor";
      return false;
    }
    ge_inputs.emplace_back(*ge_tensor);
  }
  *ge_tensors_ptr = ge_inputs;
  return true;
}
1471
AoeTuning(const FuncGraphPtr & anf_graph)1472 bool GeGraphExecutor::AoeTuning(const FuncGraphPtr &anf_graph) {
1473 if (!CreateSession({})) {
1474 MS_LOG(ERROR) << "Failed to create ge session";
1475 return false;
1476 }
1477 std::map<std::string, std::string> ge_options;
1478 GetGeGraphOptions(anf_graph, &ge_options);
1479 auto df_graph = CompileGraphCommon(anf_graph, &ge_options);
1480 if (df_graph == nullptr) {
1481 MS_LOG(ERROR) << "Input param graph is nullptr.";
1482 return false;
1483 }
1484 std::vector<::ge::Tensor> ge_inputs;
1485 if (!GetOneRealInputs(anf_graph, &ge_inputs)) {
1486 MS_LOG(ERROR) << "Failed to get one real inputs";
1487 return false;
1488 }
1489 AoeApiTuning tuning;
1490 auto status = tuning.AoeTurningGraph(ge_session_, df_graph, ge_inputs, context_, config_infos_);
1491 if (status != kSuccess) {
1492 MS_LOG(ERROR) << "Failed to call AoeTurningGraph";
1493 return false;
1494 }
1495 return true;
1496 }
1497
RunGeInitGraph(uint32_t init_graph_id,const std::vector<std::string> & init_data_names,const transform::TensorOrderMap & params_vals)1498 bool GeGraphExecutor::RunGeInitGraph(uint32_t init_graph_id, const std::vector<std::string> &init_data_names,
1499 const transform::TensorOrderMap ¶ms_vals) {
1500 std::vector<tensor::TensorPtr> init_data_tensors;
1501 for (auto &item : init_data_names) {
1502 auto it = params_vals.find(item);
1503 if (it == params_vals.end()) {
1504 MS_LOG(ERROR) << "Cannot find parameter " << item << " in parameter map";
1505 return false;
1506 }
1507 init_data_tensors.push_back(it->second);
1508 }
1509 MS_LOG(DEBUG) << "ExecInitGraph start.";
1510 std::vector<::ge::Tensor> ge_inputs;
1511 for (size_t i = 0; i < init_data_tensors.size(); i++) {
1512 auto &input = init_data_tensors[i];
1513 auto ge_tensor = transform::TransformUtil::ConvertTensor(input, kOpFormat_NCHW, false);
1514 if (ge_tensor == nullptr) {
1515 MS_LOG(ERROR) << "Failed to converter input " << i << " ME Tensor to GE Tensor";
1516 return false;
1517 }
1518 ge_inputs.emplace_back(*ge_tensor);
1519 }
1520 std::vector<::ge::Tensor> ge_outputs;
1521 auto ge_status = ge_session_->RunGraph(init_graph_id, ge_inputs, ge_outputs);
1522 if (ge_status != ge::GRAPH_SUCCESS) {
1523 MS_LOG(ERROR) << "Exec init graph failed, graph id " << init_graph_id;
1524 return false;
1525 }
1526 MS_LOG(INFO) << "Exec init graph success, graph id " << init_graph_id;
1527 return true;
1528 }
1529
// Runs a compiled GE graph via the asynchronous API, but blocks until the GE
// callback fires, making the call synchronous from the caller's point of view.
// Returns true only when the callback reported GRAPH_SUCCESS.
bool GeGraphExecutor::RunGeGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs,
                                      std::vector<::ge::Tensor> *outputs) {
  bool is_finished = false;
  bool end_of_sequence = false;
  // Fulfilled by the callback; paired future below blocks until completion.
  std::promise<void> promise;
  // NOTE(review): the callback captures stack locals by reference, so it must
  // not run after this frame unwinds. future.wait() guarantees that on the
  // success path; on the RunGraphAsync-failure path we assume GE does not
  // invoke the callback -- confirm against the GE API contract.
  auto call_back = [outputs, &is_finished, &end_of_sequence, &promise](ge::Status ge_status,
                                                                       const std::vector<ge::Tensor> &ge_outputs) {
    if (ge_status == ge::GRAPH_SUCCESS) {
      *outputs = ge_outputs;
      is_finished = true;
    } else if (ge_status == ge::END_OF_SEQUENCE) {
      // Data source exhausted; reported separately from generic failures.
      MS_LOG(ERROR) << "RunAsync out of range: End of sequence.";
      end_of_sequence = true;
    } else {
      MS_LOG(ERROR) << "RunAsync failed." << ge::GEGetErrorMsg();
    }
    promise.set_value();
    return;
  };
  if (ge_session_ == nullptr) {
    MS_LOG(ERROR) << "The GE session is null, can't run the graph!";
    return false;
  }
  ge::Status ret = ge_session_->RunGraphAsync(graph_id, inputs, call_back);
  if (ret != ge::GRAPH_SUCCESS) {
    MS_LOG(ERROR) << "Call GE RunGraphAsync Failed: " << ge::GEGetErrorMsg();
    return false;
  }
  // Block until the callback has run, then translate its result flags.
  auto future = promise.get_future();
  future.wait();
  if (end_of_sequence) {
    MS_LOG(ERROR) << "Failed to call GE RunGraphAsync: End of sequence";
    return false;
  }
  return is_finished;
}
1566
RunDataFlowGraphAsync(uint32_t graph_id,const std::vector<::ge::Tensor> & inputs,std::vector<::ge::Tensor> * outputs)1567 bool GeGraphExecutor::RunDataFlowGraphAsync(uint32_t graph_id, const std::vector<::ge::Tensor> &inputs,
1568 std::vector<::ge::Tensor> *outputs) {
1569 ge::DataFlowInfo data_flow_info;
1570 int time_out = 3000; // set the timeout to 3000s.
1571 auto ret = ge_session_->FeedDataFlowGraph(graph_id, inputs, data_flow_info, time_out);
1572 if (ret != ge::SUCCESS) {
1573 MS_LOG(ERROR) << "Feed input data failed.";
1574 return false;
1575 }
1576 ret = ge_session_->FetchDataFlowGraph(graph_id, *outputs, data_flow_info, time_out);
1577 if (ret != ge::SUCCESS) {
1578 MS_LOG(ERROR) << "Fetch output data failed.";
1579 return false;
1580 }
1581 return true;
1582 }
1583
// Prepares the GE input/output tensor lists for one RefMode execution:
// copies each host input into its pre-allocated device buffer, appends the
// RefData (weight-like) tensors, and sets up the output tensor list according
// to whether the GE graph is static.
bool GeGraphExecutor::InitInputDataTensor(const std::vector<tensor::Tensor> &inputs,
                                          std::vector<::ge::Tensor> *ge_inputs, std::vector<::ge::Tensor> *ge_outputs) {
  if (inputs_buffer_infos_.size() != inputs.size()) {
    MS_LOG(ERROR) << "Input data info size " << inputs_buffer_infos_.size() << " != inputs size " << inputs.size();
    return false;
  }
  if (memory_manager_ == nullptr) {
    MS_LOG(ERROR) << "Memory manager or context manager is nullptr";
    return false;
  }
  for (size_t i = 0; i < inputs.size(); i++) {
    auto &input = inputs[i];
    MS_LOG(INFO) << "Input " << i << " shape " << tensor::ShapeToString(input.shape_c()) << ", datatype "
                 << input.data_type();
    auto tensor_size = input.Size();
    auto &input_info = inputs_buffer_infos_[i];
    // Device buffers were sized at build time; a larger host tensor cannot fit.
    if (input_info.max_size < tensor_size) {
      MS_LOG(ERROR) << "Input " << i << " data size invalid, graph size " << input_info.max_size << ", given size "
                    << tensor_size;
      return false;
    }
    if (!memory_manager_->MemcpyHost2Device(input_info.device_addr, input_info.max_size, input.data_c(), tensor_size)) {
      return false;
    }

    // Refresh the GE tensor's shape to the actual (possibly dynamic) host shape.
    SetGeTensorShape(&input_info.ge_tensor, input.shape_c());
    ge_inputs->push_back(input_info.ge_tensor);
  }
  // RefData tensors are appended after the data inputs, in declaration order.
  for (auto &item : ref_data_infos_) {
    if (dyn_kv_cache_info_.dynamic_kv_cache) {
      // With dynamic KV cache the real ref shape must be recomputed each run.
      ShapeVector ref_real_shape = transform::TransformUtil::ConvertGeShape(item.ge_tensor.GetTensorDesc().GetShape());
      SetRefShape(&ref_real_shape, false, item.name);
      SetGeTensorShape(&item.ge_tensor, ref_real_shape);
      MS_LOG(INFO) << "Update RefData Input " << item.name << " shape to " << tensor::ShapeToString(ref_real_shape);
    }
    ge_inputs->push_back(item.ge_tensor);
  }
  if (!dyn_kv_cache_info_.is_ge_graph_static_) {
    // Dynamic GE graph: hand GE empty tensors so it allocates output buffers.
    ge_outputs->resize(outputs_buffer_infos_.size());
    for (auto &ge_tensor : *ge_outputs) {
      auto ret = ge_tensor.SetData(nullptr, 0U, [](uint8_t *) -> void {});
      if (ret != ge::GRAPH_SUCCESS) {
        MS_LOG(ERROR) << "Failed to call ge::Tensor SetData(nullptr, 0, DeleteFunc) for output";
        return false;
      }
    }
  } else {
    // Static GE graph: outputs reuse the pre-allocated device buffers.
    for (auto &output : outputs_buffer_infos_) {
      ge_outputs->push_back(output.ge_tensor);
    }
  }
  return true;
}
1637
BuildGraphRefMode(const FuncGraphPtr & anf_graph,uint32_t graph_id)1638 bool GeGraphExecutor::BuildGraphRefMode(const FuncGraphPtr &anf_graph, uint32_t graph_id) {
1639 MS_LOG(INFO) << "Call GE CompileGraph start, graph id " << graph_id;
1640 ge::Status ret = ge_session_->CompileGraph(graph_id);
1641 if (ret != ge::GRAPH_SUCCESS) {
1642 MS_LOG(ERROR) << "Call GE CompileGraph Failed: " << ge::GEGetErrorMsg();
1643 return false;
1644 }
1645 MS_LOG(INFO) << "Call GE CompileGraph end, graph id " << graph_id;
1646 if (!InitMemoryContextManager()) {
1647 return false;
1648 }
1649
1650 if (!InitRefDataDeviceTensor()) {
1651 MS_LOG(ERROR) << "Failed to init ref data device data";
1652 return false;
1653 }
1654
1655 // ref data input memories have been allocated
1656 // for input data memory
1657 if (!InitInputDeviceTensor(anf_graph)) {
1658 MS_LOG(ERROR) << "Failed to init input data device data";
1659 return false;
1660 }
1661
1662 // for output memory
1663 if (!InitOutputDeviceTensor(anf_graph, graph_id)) {
1664 MS_LOG(ERROR) << "Failed to init input data device data";
1665 return false;
1666 }
1667 return true;
1668 }
1669
RunGraphRefMode(uint32_t graph_id,const std::vector<tensor::Tensor> & inputs,std::vector<tensor::Tensor> * outputs)1670 bool GeGraphExecutor::RunGraphRefMode(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
1671 std::vector<tensor::Tensor> *outputs) {
1672 MS_LOG(INFO) << "RunGraphRefMode begin";
1673 std::vector<::ge::Tensor> ge_inputs;
1674 std::vector<::ge::Tensor> ge_outputs;
1675 if (!InitRealShapeParam(inputs)) {
1676 return false;
1677 }
1678 if (!InitInputDataTensor(inputs, &ge_inputs, &ge_outputs)) {
1679 MS_LOG(ERROR) << "Init input tensor failed in run graph.";
1680 return false;
1681 }
1682 auto stream = context_manager_->GetDefaultStream();
1683 if (!RunGraphWithStreamAsync(graph_id, stream, ge_inputs, &ge_outputs)) {
1684 MS_LOG(ERROR) << "Failed in run graph with stream async.";
1685 return false;
1686 }
1687 if (!SyncDeviceOutputsToHost(outputs, &ge_outputs)) {
1688 MS_LOG(ERROR) << "Failed in sync device output to host.";
1689 return false;
1690 }
1691 MS_LOG(INFO) << "RunGraphRefMode end";
1692 return true;
1693 }
1694
SyncDeviceOutputsToHost(std::vector<tensor::Tensor> * outputs,std::vector<::ge::Tensor> * ge_outputs)1695 bool GeGraphExecutor::SyncDeviceOutputsToHost(std::vector<tensor::Tensor> *outputs,
1696 std::vector<::ge::Tensor> *ge_outputs) {
1697 UpdateOutputShapeInfo(ge_outputs);
1698
1699 size_t output_size = outputs_buffer_infos_.size();
1700 if (!outputs->empty()) {
1701 if (outputs->size() != output_size) {
1702 MS_LOG(ERROR) << "Invalid output size, outputs' size " << outputs->size() << "ge tensor size " << output_size;
1703 return false;
1704 }
1705 // cppcheck-suppress cppcheckError
1706 for (size_t i = 0; i < output_size; ++i) {
1707 auto &output_info = outputs_buffer_infos_[i];
1708 auto &output = (*outputs)[i];
1709 if (output.Size() < output_info.max_size) {
1710 MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output.Size()
1711 << " is less than actual output size " << output_info.max_size;
1712 }
1713 if ((*outputs)[i].data_c() == nullptr) {
1714 MS_LOG(ERROR) << "Output data ptr is nullptr.";
1715 return false;
1716 }
1717 auto mem_ret = memory_manager_->MemcpyDevice2Host(reinterpret_cast<uint8_t *>(output.data_c()), output.Size(),
1718 output_info.device_addr, output_info.max_size);
1719 if (!mem_ret) {
1720 MS_LOG(ERROR) << "Failed to copy output data, dst size: " << output.Size()
1721 << ", src size: " << output_info.max_size;
1722 return false;
1723 }
1724 MS_LOG(INFO) << "Output " << i << " shape " << tensor::ShapeToString(output_info.shape) << ", datatype "
1725 << output_info.dtype;
1726 }
1727 } else {
1728 for (size_t i = 0; i < output_size; i++) {
1729 auto &output_info = outputs_buffer_infos_[i];
1730 tensor::Tensor ms_tensor(output_info.dtype, output_info.shape);
1731 auto mem_ret =
1732 memory_manager_->MemcpyDevice2Host(reinterpret_cast<uint8_t *>(ms_tensor.data_c()), ms_tensor.Size(),
1733 output_info.device_addr, output_info.max_size);
1734 if (!mem_ret) {
1735 MS_LOG(ERROR) << "Failed to copy output data, dst size: " << ms_tensor.Size()
1736 << ", src size: " << output_info.max_size;
1737 return false;
1738 }
1739 MS_LOG(INFO) << "Output " << i << " shape " << tensor::ShapeToString(output_info.shape) << ", datatype "
1740 << output_info.dtype;
1741 outputs->push_back(ms_tensor);
1742 }
1743 }
1744 return true;
1745 }
1746
RunGraphWithStreamAsync(uint32_t graph_id,void * stream,const std::vector<GeTensor> & inputs,std::vector<GeTensor> * outputs)1747 bool GeGraphExecutor::RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector<GeTensor> &inputs,
1748 std::vector<GeTensor> *outputs) {
1749 MS_EXCEPTION_IF_NULL(outputs);
1750 for (auto ge_input : inputs) {
1751 MS_LOG(INFO) << "In ge graph " << graph_id << ", input for RunGraphWithStreamAsync : "
1752 << tensor::ShapeToString(
1753 transform::TransformUtil::ConvertGeShape(ge_input.GetTensorDesc().GetShape()));
1754 }
1755 MS_LOG(INFO) << "Run the graph in GE with " << inputs.size() << " inputs";
1756 struct timeval start_time;
1757 (void)gettimeofday(&start_time, nullptr);
1758
1759 ge::Status ret = ge_session_->RunGraphWithStreamAsync(graph_id, stream, inputs, *outputs);
1760 if (ret != ge::GRAPH_SUCCESS) {
1761 MS_LOG(ERROR) << "Call GE RunGraphWithStreamAsync Failed, ret is: " << ret;
1762 return false;
1763 }
1764 if (!context_manager_->SyncStream(stream)) {
1765 MS_LOG(ERROR) << "Sync stream for RunGraphWithStreamAsync failed";
1766 return false;
1767 }
1768 struct timeval end_time;
1769 (void)gettimeofday(&end_time, nullptr);
1770 const uint64_t kUSecondInSecond = 1000000;
1771 uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
1772 cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
1773 MS_LOG(INFO) << "Call GE RunGraphWithStreamAsync Success in " << cost << " us, GE outputs num: " << outputs->size()
1774 << ", graph id: " << graph_id;
1775
1776 return true;
1777 }
1778
RunGraph(uint32_t graph_id,const std::vector<tensor::Tensor> & inputs,std::vector<tensor::Tensor> * outputs,const std::map<string,string> &)1779 bool GeGraphExecutor::RunGraph(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
1780 std::vector<tensor::Tensor> *outputs,
1781 const std::map<string, string> & /* compile_options */) {
1782 if (outputs == nullptr) {
1783 MS_LOG(ERROR) << " Input param is nullptr.";
1784 return false;
1785 }
1786 MS_LOG(INFO) << "Run ge graph [" << graph_id << "] with " << inputs.size() << " inputs";
1787 for (size_t i = 0; i < inputs.size(); i++) {
1788 auto &input = inputs[i];
1789 MS_LOG(INFO) << "Input " << i << " shape " << input.shape_c() << ", datatype " << input.data_type();
1790 }
1791
1792 if (ref_mode_flag_ != transform::RefModeFlag::kRefModeNone) {
1793 return RunGraphRefMode(graph_id, inputs, outputs);
1794 }
1795 std::vector<::ge::Tensor> ge_inputs;
1796 for (size_t i = 0; i < inputs.size(); i++) {
1797 auto &input = inputs[i];
1798 auto ge_tensor =
1799 transform::TransformUtil::ConvertTensor(std::make_shared<tensor::Tensor>(input), kOpFormat_NCHW, false);
1800 if (ge_tensor == nullptr) {
1801 MS_LOG(ERROR) << "Failed to converter input " << i << " ME Tensor to GE Tensor";
1802 return false;
1803 }
1804 ge_inputs.emplace_back(*ge_tensor);
1805 }
1806 for (auto &item : ref_data_infos_) {
1807 ge_inputs.emplace_back(item.ge_tensor);
1808 }
1809 std::vector<::ge::Tensor> ge_outputs;
1810 auto time_start = std::chrono::system_clock::now();
1811 auto ret = !is_data_flow_graph_ ? RunGeGraphAsync(graph_id, ge_inputs, &ge_outputs)
1812 : RunDataFlowGraphAsync(graph_id, ge_inputs, &ge_outputs);
1813 if (!ret) {
1814 MS_LOG(ERROR) << "Exec compute graph failed, graph id " << graph_id;
1815 return false;
1816 }
1817 auto time_cost =
1818 std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now() - time_start).count();
1819 MS_LOG(INFO) << "Call GE RunGraph Success in " << time_cost << " us, graph id " << graph_id
1820 << " the GE outputs num is: " << ge_outputs.size();
1821
1822 if (!outputs->empty()) {
1823 if (outputs->size() != ge_outputs.size()) {
1824 MS_LOG(ERROR) << "Invalid output size, outputs' size " << outputs->size() << "ge tensor size "
1825 << ge_outputs.size();
1826 return false;
1827 }
1828 for (size_t i = 0; i < outputs->size(); ++i) {
1829 const auto &tensor = ge_outputs[i];
1830 auto &output = (*outputs)[i];
1831 if (output.Size() < LongToSize(UlongToLong(tensor.GetSize()))) {
1832 MS_LOG(EXCEPTION) << "Output node " << i << "'s mem size " << output.Size()
1833 << " is less than actual output size " << tensor.GetSize();
1834 }
1835 if ((*outputs)[i].data_c() == nullptr) {
1836 MS_LOG(ERROR) << "Output data ptr is nullptr.";
1837 return false;
1838 }
1839 auto mem_ret = common::huge_memcpy(reinterpret_cast<uint8_t *>(output.data_c()), output.Size(), tensor.GetData(),
1840 tensor.GetSize());
1841 if (mem_ret != EOK) {
1842 MS_LOG(ERROR) << "Failed to copy output data, dst size: " << output.Size()
1843 << ", src size: " << tensor.GetSize();
1844 return false;
1845 }
1846 }
1847 } else {
1848 for (size_t i = 0; i < ge_outputs.size(); i++) {
1849 auto &ge_tensor = ge_outputs[i];
1850 auto ms_tensor = ConvertGeTensorNoCopy(&ge_tensor, graph_id, i);
1851 if (ms_tensor == nullptr) {
1852 MS_LOG(ERROR) << "Failed to converter output " << i << " GE Tensor to ME Tensor";
1853 return false;
1854 }
1855 MS_LOG(INFO) << "Output " << i << " shape " << tensor::ShapeToString(ms_tensor->shape_c()) << ", datatype "
1856 << ms_tensor->data_type();
1857 outputs->push_back(*ms_tensor);
1858 }
1859 }
1860 graph_inputs_[graph_id] = inputs;
1861 graph_outputs_[graph_id] = *outputs;
1862 MS_LOG(INFO) << "GE run graph " << graph_id << " end.";
1863 return true;
1864 }
1865
GetInputInfos(uint32_t graph_id)1866 std::vector<tensor::Tensor> GeGraphExecutor::GetInputInfos(uint32_t graph_id) {
1867 return graph_inputs_.find(graph_id) != graph_inputs_.end() ? graph_inputs_.at(graph_id)
1868 : std::vector<tensor::Tensor>();
1869 }
1870
ConvertGeTensorNoCopy(::ge::Tensor * ge_tensor_ptr,uint32_t graph_id,size_t idx)1871 tensor::TensorPtr GeGraphExecutor::ConvertGeTensorNoCopy(::ge::Tensor *ge_tensor_ptr, uint32_t graph_id, size_t idx) {
1872 auto &ge_tensor = *ge_tensor_ptr;
1873 auto ge_tensor_desc = ge_tensor.GetTensorDesc();
1874 auto me_shape = transform::TransformUtil::ConvertGeShape(ge_tensor_desc.GetShape());
1875 if (original_graph_outputs_.find(graph_id) == original_graph_outputs_.end()) {
1876 MS_LOG(ERROR) << "Graph original outputs with the given graph id is not found.";
1877 return nullptr;
1878 }
1879 auto original_outputs = original_graph_outputs_[graph_id];
1880 if (idx >= original_outputs.size()) {
1881 MS_LOG(ERROR) << "Graph output index is out of range.";
1882 return nullptr;
1883 }
1884 TypeId type_id = static_cast<TypeId>(original_outputs[idx]->data_type_c());
1885 if (type_id == kTypeUnknown) {
1886 MS_LOG(ERROR) << "Could not convert Ge Tensor because of unsupported data type: "
1887 << static_cast<int>(ge_tensor_desc.GetDataType());
1888 return nullptr;
1889 }
1890 if (ge_tensor_desc.GetPlacement() != ::ge::kPlacementHost) {
1891 MS_LOG(ERROR) << "It is not supported that graph output data's placement is device now.";
1892 return nullptr;
1893 }
1894 auto &&ge_data_uni = ge_tensor.ResetData();
1895 auto deleter = ge_data_uni.get_deleter();
1896 auto ge_data = ge_data_uni.release();
1897 if (ge_data == nullptr) {
1898 MS_LOG(ERROR) << "Ge data cannot be nullptr";
1899 return nullptr;
1900 }
1901 constexpr int64_t kTensorAlignBytes = 64;
1902 if (reinterpret_cast<uintptr_t>(ge_data) % kTensorAlignBytes != 0) {
1903 MS_LOG(ERROR) << "Skip zero-copy ge tensor " << reinterpret_cast<uintptr_t>(ge_data)
1904 << ", bytes not aligned with expected.";
1905 return nullptr;
1906 }
1907 int64_t elem_num = 1;
1908 for (size_t i = 0; i < me_shape.size(); ++i) {
1909 elem_num *= me_shape[i];
1910 }
1911 if (GetTypeByte(TypeIdToType(type_id)) * elem_num != ge_tensor.GetSize()) {
1912 MS_LOG(ERROR) << "Output datatype error! Output tensor size from GE RunGraph does not match.";
1913 return nullptr;
1914 }
1915 auto tensor_data = std::make_shared<TensorRefData>(ge_data, elem_num, ge_tensor.GetSize(), me_shape.size(), deleter);
1916 return std::make_shared<tensor::Tensor>(type_id, me_shape, tensor_data);
1917 }
1918
GetOutputInfos(uint32_t graph_id)1919 std::vector<tensor::Tensor> GeGraphExecutor::GetOutputInfos(uint32_t graph_id) {
1920 return graph_outputs_.find(graph_id) != graph_outputs_.end() ? graph_outputs_.at(graph_id)
1921 : std::vector<tensor::Tensor>();
1922 }
1923
CreateAsCustomFuncGraph(const FuncGraphPtr & func_graph,const std::map<std::string,std::string> & graph_options)1924 bool GeGraphExecutor::CreateAsCustomFuncGraph(const FuncGraphPtr &func_graph,
1925 const std::map<std::string, std::string> &graph_options) {
1926 Buffer buffer;
1927 auto files = ReadFileNames(build_cache_dir_);
1928 for (auto &file : files) {
1929 if (file.find(".om") != std::string::npos && file.find(graph_name_) != std::string::npos) {
1930 auto om_path = build_cache_dir_ + "/" + file;
1931 buffer = ReadFile(om_path);
1932 break;
1933 }
1934 }
1935 if (buffer.DataSize() == 0 || buffer.Data() == nullptr) {
1936 MS_LOG(ERROR) << "Failed to read model buffer file, model cache " << build_cache_dir_;
1937 return false;
1938 }
1939 std::map<std::string, ValuePtr> attr_map;
1940 SetOptionsIntoOfflineModel(session_options_, &attr_map);
1941 std::vector<std::string> ref_datas;
1942 std::transform(ref_data_infos_.begin(), ref_data_infos_.end(), std::back_inserter(ref_datas),
1943 [](auto &item) { return item.name; });
1944 DynKVCacheSaveInfo save_info;
1945 save_info.seq_length_dyn = dyn_kv_cache_info_.seq_length_dyn;
1946 save_info.batch_size_dyn = dyn_kv_cache_info_.batch_size_dyn;
1947 save_info.kv_cache_layout = dyn_kv_cache_info_.kv_cache_layout;
1948
1949 if (!CustomAscendUtils::CreateCustomFuncGraph(func_graph, buffer, graph_name_, attr_map, ref_datas, save_info)) {
1950 MS_LOG(ERROR) << "Create custom func graph failed";
1951 return false;
1952 }
1953 return true;
1954 }
1955
// Builds the model offline when RefMode is enabled: compiles the graph with a
// model cache directory configured, then repacks the resulting OM model into
// the func graph as a custom node. Returns true on success, and also when
// RefMode is disabled (offline build not applicable).
bool GeGraphExecutor::OfflineBuildGraph(const FuncGraphPtr &graph) {
  if (ref_mode_flag_ == transform::RefModeFlag::kRefModeNone) {
    MS_LOG(INFO) << "parameter_as_refdata in ascend_context is none, skip offline build graph";
    return true;
  }
  MS_LOG(INFO) << "Set offline mode";
  // Route GE's build artifacts into the offline model cache directory.
  std::map<std::string, std::string> extra_session_options;
  if (!SetOfflineBuildModelCacheDir(&extra_session_options)) {
    return false;
  }
  if (!CreateSession(extra_session_options)) {
    MS_LOG(ERROR) << "Failed to create ge session";
    return false;
  }

  if (!SetDynamicKVCache(graph)) {
    MS_LOG(ERROR) << "Failed to init dynamic KVCache info";
    return false;
  }
  uint32_t graph_id = 0;
  std::map<std::string, std::string> ge_options;
  GetGeGraphOptions(graph, &ge_options);
  auto df_graph = CompileGraphCommon(graph, &ge_options);
  if (df_graph == nullptr) {
    MS_LOG(ERROR) << "Input param graph is nullptr.";
    return false;
  }
  if (!AddGraph(df_graph, ge_options, &graph_id)) {
    MS_LOG(ERROR) << "Failed to add compute graph, graph name " << graph->ToString();
    return false;
  }
  compute_graph_id_list_.push_back(graph_id);
  MS_LOG(INFO) << "Call GE CompileGraph start, graph id " << graph_id;
  // CompileGraph writes the built OM model into the cache dir configured above.
  ge::Status ret = ge_session_->CompileGraph(graph_id);
  if (ret != ge::GRAPH_SUCCESS) {
    MS_LOG(ERROR) << "Call GE CompileGraph Failed: " << ge::GEGetErrorMsg();
    return false;
  }
  MS_LOG(INFO) << "Call GE CompileGraph end, graph id " << graph_id;
  // Wrap the cached OM model into the func graph for later standalone loading.
  if (!CreateAsCustomFuncGraph(graph, ge_options)) {
    MS_LOG(ERROR) << "Failed to CreateAsCustomFuncGraph";
    return false;
  }
  return true;
}
2001
// Definitions of GeSessionManager's static members: the process-wide map from
// lite session id to its shared GE session context, and the mutex guarding it.
std::map<int64_t, std::shared_ptr<GeSessionContext>> GeSessionManager::ge_session_map_;
std::mutex GeSessionManager::session_mutex_;
2004
// Creates a GE session or returns the one already shared under session_id.
// kUnkonwnSessionId yields a private, unshared session. Shared sessions are
// tracked through weak_ptr entries in ge_session_map_, and reuse requires the
// session options to match the ones the session was first created with.
std::shared_ptr<ge::Session> GeSessionManager::CreateGeSession(
  int64_t session_id, const std::map<std::string, std::string> &session_options) {
  std::shared_ptr<ge::Session> ge_session = nullptr;
  if (session_id == kUnkonwnSessionId) {
    // No sharing requested: create a fresh session not recorded in the map.
    ge_session = std::make_shared<ge::Session>(session_options);
    if (ge_session == nullptr) {
      MS_LOG(ERROR) << "Failed to create ge session";
      return nullptr;
    }
    MS_LOG(INFO) << "Create ge session successfully, which will not be shared with other graph";
    return ge_session;
  }
  std::lock_guard<std::mutex> lock(session_mutex_);
  auto s_it = ge_session_map_.find(session_id);
  if (s_it != ge_session_map_.end() && s_it->second != nullptr) {
    // The map holds a weak_ptr; lock() fails when all owners have released it.
    ge_session = s_it->second->ge_session.lock();
  }
  if (ge_session == nullptr) {
    for (auto &option : session_options) {
      MS_LOG(INFO) << "GE Session (lite session id " << session_id << ") option " << option.first << " = "
                   << option.second;
    }
    ge_session = std::make_shared<ge::Session>(session_options);
    if (ge_session == nullptr) {
      MS_LOG(ERROR) << "Failed to create ge session";
      return nullptr;
    }
    auto session_context = std::make_shared<GeSessionContext>();
    if (session_context == nullptr) {
      MS_LOG(ERROR) << "Failed to create GeSessionContext";
      return nullptr;
    }
    // Record the new session (weakly) and the options it was created with.
    session_context->ge_session = ge_session;
    session_context->session_options = session_options;
    ge_session_map_[session_id] = session_context;
    MS_LOG(INFO) << "Create ge session successfully, lite session id: " << session_id;
  } else {
    auto map_as_string = [](const std::map<std::string, std::string> &options) {
      std::stringstream ss;
      ss << "{";
      for (auto &item : options) {
        ss << "" << item.first << ":" << item.second << ",";
      }
      ss << "}";
      return ss.str();
    };
    auto old_options = s_it->second->session_options;
    // Reusing a shared session with different options would silently apply
    // the first caller's options to this caller; reject that.
    if (old_options != session_options) {
      MS_LOG(ERROR)
        << "Session options is not equal in diff config infos when models' weights are shared, last session options: "
        << map_as_string(old_options) << ", current session options: " << map_as_string(session_options);
      return nullptr;
    }
    MS_LOG(INFO) << "Get ge session from session map, lite session id: " << session_id;
  }
  return ge_session;
}
2062
UpdateSessionVariables(int64_t session_id,const std::vector<std::string> & graph_variables)2063 std::set<std::string> GeSessionManager::UpdateSessionVariables(int64_t session_id,
2064 const std::vector<std::string> &graph_variables) {
2065 std::set<std::string> new_variables;
2066 if (session_id == kUnkonwnSessionId) {
2067 std::transform(graph_variables.begin(), graph_variables.end(), std::inserter(new_variables, new_variables.begin()),
2068 [](const auto &item) { return item; });
2069 return new_variables;
2070 }
2071 std::lock_guard<std::mutex> lock(session_mutex_);
2072 std::shared_ptr<ge::Session> ge_session = nullptr;
2073 auto s_it = ge_session_map_.find(session_id);
2074 if (s_it != ge_session_map_.end() && s_it->second != nullptr) {
2075 ge_session = s_it->second->ge_session.lock();
2076 }
2077 if (ge_session == nullptr) {
2078 std::transform(graph_variables.begin(), graph_variables.end(), std::inserter(new_variables, new_variables.begin()),
2079 [](const auto &item) { return item; });
2080 return new_variables;
2081 }
2082 auto ¤t_session_variables = s_it->second->session_variables;
2083 for (auto &item : graph_variables) {
2084 if (current_session_variables.find(item) == current_session_variables.end()) {
2085 new_variables.insert(item);
2086 current_session_variables.insert(item);
2087 }
2088 }
2089 return new_variables;
2090 }
2091
TryReleaseGeSessionContext(int64_t session_id)2092 void GeSessionManager::TryReleaseGeSessionContext(int64_t session_id) {
2093 std::lock_guard<std::mutex> lock(session_mutex_);
2094 auto s_it = ge_session_map_.find(session_id);
2095 if (s_it != ge_session_map_.end()) {
2096 if (s_it->second != nullptr) {
2097 auto ge_session = s_it->second->ge_session.lock();
2098 if (ge_session == nullptr) {
2099 ge_session_map_.erase(s_it);
2100 }
2101 } else {
2102 ge_session_map_.erase(s_it);
2103 }
2104 }
2105 }
2106
GetGeSessionContext(int64_t session_id)2107 std::shared_ptr<GeSessionContext> GeSessionManager::GetGeSessionContext(int64_t session_id) {
2108 std::lock_guard<std::mutex> lock(session_mutex_);
2109 auto s_it = ge_session_map_.find(session_id);
2110 if (s_it != ge_session_map_.end()) {
2111 return s_it->second;
2112 }
2113 return nullptr;
2114 }
2115
GeGraphExecutorCreator(const std::shared_ptr<Context> & ctx,const ConfigInfos & config_infos)2116 static std::shared_ptr<device::GraphExecutor> GeGraphExecutorCreator(const std::shared_ptr<Context> &ctx,
2117 const ConfigInfos &config_infos) {
2118 auto ge_executor = std::make_shared<GeGraphExecutor>(ctx, config_infos);
2119 if (ge_executor == nullptr || !ge_executor->Init()) {
2120 MS_LOG(ERROR) << "Failed to init GeGraphExecutor";
2121 return nullptr;
2122 }
2123 return ge_executor;
2124 }
2125
2126 REG_DELEGATE(kAscend, kProviderGe, GeGraphExecutorCreator)
2127 } // namespace mindspore
2128