1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * Limitations under the License.
15 */
16
17 #include "transform/graph_ir/graph_runner.h"
18 #include <algorithm>
19 #include <string>
20 #include <memory>
21 #include <set>
22
23 #ifndef ENABLE_LITE_ACL
24 #include "pybind11/pybind11.h"
25 #endif
26 #include "utils/log_adapter.h"
27 #include "include/common/utils/config_manager.h"
28 #include "sys/time.h"
29 #include "include/common/utils/utils.h"
30 #include "include/common/utils/callbacks.h"
31 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
32 #include "transform/graph_ir/callbacks_ge.h"
33 #include "external/ge_common/ge_api_error_codes.h"
34 #include "graph/graph.h"
35 #endif
36 #include "utils/ms_context.h"
37 #include "include/common/utils/scoped_long_running.h"
38
39 #ifndef ENABLE_LITE_ACL
40 namespace py = pybind11;
41 #endif
42 namespace mindspore {
43 namespace transform {
NewSession(const SessionOptions & sess_options)44 std::shared_ptr<::ge::Session> GraphRunner::NewSession(const SessionOptions &sess_options) {
45 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
46 std::shared_ptr<::ge::Session> ret;
47 auto ms_context = MsContext::GetInstance();
48 MS_EXCEPTION_IF_NULL(ms_context);
49 if (ms_context->backend_policy() == "ge" || ms_context->backend_policy() == "ms") {
50 ret = std::make_shared<::ge::Session>(sess_options);
51 if (ret == nullptr) {
52 MS_LOG(EXCEPTION) << "Create GE session failed!";
53 }
54 MS_LOG(INFO) << "Create new GE session success!";
55 return ret;
56 }
57 #endif
58
59 MS_LOG(DEBUG) << "no GE client, return nullptr!";
60 return nullptr;
61 }
62
GraphRunner(const GraphRunnerOptions & options)63 GraphRunner::GraphRunner(const GraphRunnerOptions &options)
64 : options_(options), graph_manager_(DfGraphManager::GetInstance()) {
65 if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) {
66 MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode";
67 }
68 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
69 auto ms_context = MsContext::GetInstance();
70 MS_EXCEPTION_IF_NULL(ms_context);
71 #endif
72 if (options.sess_ptr != nullptr) {
73 sess_ = options.sess_ptr;
74 } else {
75 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
76 if (ms_context->backend_policy() == "ge") {
77 sess_ = NewSession(options.options);
78 if (sess_ == nullptr) {
79 MS_LOG(EXCEPTION) << "Graph runner sess_ is nullptr!";
80 }
81 }
82 #endif
83 }
84
85 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
86 if (ms_context->backend_policy() != "ge") {
87 return;
88 }
89 #ifndef ENABLE_LITE_ACL
90 // register the callback function
91 MS_EXCEPTION_IF_NULL(sess_);
92 if (sess_->RegisterCallBackFunc(callbacks::kCheckPoint, callbacks::CheckpointSaveCallback) != ::ge::GRAPH_SUCCESS) {
93 MS_LOG(EXCEPTION) << "Register callback failed!";
94 }
95 if (sess_->RegisterCallBackFunc(callbacks::kSummary, callbacks::SummarySaveCallback) != ::ge::GRAPH_SUCCESS) {
96 MS_LOG(EXCEPTION) << "Register summary callback failed!";
97 }
98 #endif
99 #endif
100 }
101
AddGraph(const std::string & name)102 Status GraphRunner::AddGraph(const std::string &name) {
103 DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name);
104 if (wrap_ptr == nullptr) {
105 MS_LOG(WARNING) << "Get graph form DfGraphManager failed!";
106 return Status::NOT_FOUND;
107 }
108
109 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
110 auto ms_context = MsContext::GetInstance();
111 MS_EXCEPTION_IF_NULL(ms_context);
112 if (ms_context->backend_policy() != "ge" && ms_context->backend_policy() != "ms") {
113 return Status::SUCCESS;
114 }
115 std::set<string> saved_graph = graph_manager_.GetSavedGraphs();
116 auto iter_find = saved_graph.find(std::to_string(wrap_ptr->id_));
117 if (iter_find != saved_graph.end()) {
118 MS_LOG(INFO) << "The graph is already added, graph name: " << name;
119 return Status::SUCCESS;
120 }
121
122 graph_manager_.AddSavedGraphs(std::to_string(wrap_ptr->id_));
123 if (!wrap_ptr->is_added_to_ge_session_) {
124 MS_LOG(INFO) << "Add the graph " << (*wrap_ptr).name_ << " to GE, it's id is: " << (*wrap_ptr).id_;
125 MS_EXCEPTION_IF_NULL(sess_);
126 auto ret = sess_->AddGraph(static_cast<uint32_t>(wrap_ptr->id_), *(wrap_ptr->graph_ptr_), wrap_ptr->options_);
127 if (ret != ge::GRAPH_SUCCESS) {
128 MS_LOG(ERROR) << "AddGraph to Ge Session failed, ret: " << ret;
129 return Status::FAILED;
130 }
131 wrap_ptr->is_added_to_ge_session_ = true;
132 }
133 #endif
134 return Status::SUCCESS;
135 }
136
RunGraph(const RunOptions & options,const std::vector<GeTensorPtr> & inputs,std::vector<GeTensorPtr> * outputs)137 Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<GeTensorPtr> &inputs,
138 std::vector<GeTensorPtr> *outputs) {
139 std::string name = options.name;
140 if (name.empty()) {
141 MS_LOG(ERROR) << "The graph name is null";
142 return Status::INVALID_ARGUMENT;
143 }
144
145 DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name);
146 if (wrap_ptr == nullptr) {
147 MS_LOG(INFO) << "Get graph form DfGraphManager failed!";
148 return Status::NOT_FOUND;
149 }
150
151 if (wrap_ptr->graph_ptr_ == nullptr) {
152 MS_LOG(WARNING) << "The graph is null";
153 return Status::NOT_FOUND;
154 }
155
156 // call ::ge::RunGraph() to exec a graph;
157 std::vector<GeTensor> ge_inputs;
158 std::vector<GeTensor> ge_outputs;
159
160 (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs),
161 [](const GeTensorPtr &i) { return *i; });
162
163 MS_LOG(INFO) << "Run the graph " << name << " in GE with " << ge_inputs.size() << " inputs";
164
165 struct timeval start_time;
166 struct timeval end_time;
167 (void)gettimeofday(&start_time, nullptr);
168
169 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
170 auto ms_context = MsContext::GetInstance();
171 MS_EXCEPTION_IF_NULL(ms_context);
172 if (ms_context->backend_policy() == "ge") {
173 if (sess_ == nullptr) {
174 MS_LOG(ERROR) << "The GE session is null, can't run the graph!";
175 return Status::FAILED;
176 }
177
178 ::ge::Status ret = sess_->RunGraph(static_cast<uint32_t>(wrap_ptr->id_), ge_inputs, ge_outputs);
179 if (ret != ::ge::GRAPH_SUCCESS) {
180 MS_LOG(ERROR) << "Call GE RunGraph Failed, ret is: " << ret;
181 return Status::FAILED;
182 }
183 }
184 #else
185 ge_outputs.swap(ge_inputs);
186 #endif
187
188 (void)gettimeofday(&end_time, nullptr);
189 const uint64_t kUSecondInSecond = 1000000;
190 uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
191 cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
192 MS_LOG(INFO) << "Call GE RunGraph Success in " << cost << " us, the GE outputs num is: " << ge_outputs.size();
193
194 (void)std::transform(ge_outputs.begin(), ge_outputs.end(), std::back_inserter(*outputs),
195 [](const GeTensor &ge_tensor) { return std::make_shared<GeTensor>(ge_tensor); });
196
197 return Status::SUCCESS;
198 }
199
RunGraphAsync(const RunOptions & options,const std::vector<GeTensorPtr> & inputs,std::vector<GeTensorPtr> * outputs)200 Status GraphRunner::RunGraphAsync(const RunOptions &options, const std::vector<GeTensorPtr> &inputs,
201 std::vector<GeTensorPtr> *outputs) {
202 std::string name = options.name;
203 if (name.empty()) {
204 MS_LOG(ERROR) << "The graph name is null";
205 return Status::INVALID_ARGUMENT;
206 }
207
208 DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name);
209 if (wrap_ptr == nullptr) {
210 MS_LOG(WARNING) << "Get graph form DfGraphManager failed!";
211 return Status::NOT_FOUND;
212 }
213
214 if (wrap_ptr->graph_ptr_ == nullptr) {
215 MS_LOG(WARNING) << "The graph is null";
216 return Status::NOT_FOUND;
217 }
218
219 // call ge::RunGraphAsync() to exec a graph;
220 std::vector<GeTensor> ge_inputs;
221 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
222 auto ms_context = MsContext::GetInstance();
223 MS_EXCEPTION_IF_NULL(ms_context);
224 if (ConfigManager::GetInstance().dataset_mode() != DS_SINK_MODE) {
225 (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs),
226 [](const GeTensorPtr &i) { return *i; });
227 }
228 #endif
229 MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs";
230
231 struct timeval start_time;
232 struct timeval end_time;
233 (void)gettimeofday(&start_time, nullptr);
234
235 #if (defined ENABLE_D) || (defined ENABLE_LITE_ACL)
236 std::mutex mutex;
237 std::condition_variable condition;
238 bool is_finished = false;
239 bool end_of_sequence = false;
240 std::unique_lock<std::mutex> lock(mutex);
241 auto call_back = [=, &is_finished, &end_of_sequence, &condition](ge::Status ge_status,
242 std::vector<ge::Tensor> &ge_outputs) {
243 if (ge_status == Status::SUCCESS) {
244 for (size_t i = 0; i < ge_outputs.size(); ++i) {
245 (void)outputs->emplace_back(std::make_shared<ge::Tensor>(ge_outputs[i]));
246 }
247 is_finished = true;
248 } else if (ge_status == ge::END_OF_SEQUENCE) {
249 MS_LOG(WARNING) << "RunAsync out of range: End of sequence.";
250 end_of_sequence = true;
251 } else {
252 MS_LOG(ERROR) << "RunAsync failed.";
253 }
254 condition.notify_all();
255 return;
256 };
257 if (ms_context->backend_policy() == "ge") {
258 if (sess_ == nullptr) {
259 MS_LOG(ERROR) << "The GE session is null, can't run the graph!";
260 return Status::FAILED;
261 }
262 ge::Status ret = sess_->RunGraphAsync(static_cast<uint32_t>(wrap_ptr->id_), ge_inputs, call_back);
263 if (ret != ge::GRAPH_SUCCESS) {
264 MS_LOG(ERROR) << "Call GE RunGraphAsync Failed, ret is: " << ret;
265 return Status::FAILED;
266 }
267 if (!is_finished) {
268 condition.wait(lock);
269 }
270 if (end_of_sequence) {
271 throw(std::runtime_error("End of sequence."));
272 }
273 if (!is_finished) {
274 MS_LOG(ERROR) << "Call GE RunGraphAsync failed.";
275 return Status::FAILED;
276 }
277 }
278 #endif
279 (void)gettimeofday(&end_time, nullptr);
280 const uint64_t kUSecondInSecond = 1000000;
281 uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
282 cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
283 MS_LOG(INFO) << "Call GE RunGraph Success in " << cost << " us, the GE outputs num is: " << outputs->size();
284
285 return Status::SUCCESS;
286 }
287
RunGraph(const RunOptions & options,const std::vector<MeTensorPtr> & inputs,std::vector<MeTensorPtr> * const outputs)288 Status GraphRunner::RunGraph(const RunOptions &options, const std::vector<MeTensorPtr> &inputs,
289 std::vector<MeTensorPtr> *const outputs) {
290 std::vector<GeTensorPtr> ge_inputs;
291 for (auto it : inputs) {
292 MS_EXCEPTION_IF_NULL(it);
293 MS_LOG(INFO) << "inputs tensor's data size is: " << (*it).DataSize();
294 auto shape = (*it).shape();
295 std::string shape_str;
296 for (const auto &elem : shape) {
297 shape_str += std::to_string(elem);
298 shape_str += " ";
299 }
300 MS_LOG(INFO) << "inputs tensor's shape is: { " << shape_str << "}";
301
302 auto ge_tensor_ptr = TransformUtil::ConvertTensor(it, kOpFormat_NCHW);
303 if (ge_tensor_ptr != nullptr) {
304 (void)ge_inputs.emplace_back(ge_tensor_ptr);
305 } else {
306 MS_LOG(INFO) << "Convert input Me tensor to Ge tensor failed. Abort this graph";
307 return Status::FAILED;
308 }
309 }
310
311 std::vector<GeTensorPtr> ge_outputs;
312 Status ret;
313 {
314 // Release GIL before calling into (potentially long-running) C++ code
315 #ifndef ENABLE_LITE_ACL
316 py::gil_scoped_release release;
317 #endif
318 ret = RunGraph(options, ge_inputs, &ge_outputs);
319 }
320 if (ret != Status::SUCCESS) {
321 return ret;
322 } else {
323 // convert GeTensor to MeTensor
324 for (auto &it : ge_outputs) {
325 auto tensor = TransformUtil::ConvertGeTensor(it);
326 if (tensor != nullptr) {
327 (void)outputs->emplace_back(tensor);
328 }
329 }
330 MS_LOG(INFO) << "Return Me tensor outputs num is: " << outputs->size();
331 return Status::SUCCESS;
332 }
333 }
334
RunGraphWithStreamAsync(const RunOptions & options,void * stream,const std::vector<GeTensor> & inputs,std::vector<GeTensor> * outputs)335 Status GraphRunner::RunGraphWithStreamAsync(const RunOptions &options, void *stream,
336 const std::vector<GeTensor> &inputs, std::vector<GeTensor> *outputs) {
337 MS_EXCEPTION_IF_NULL(outputs);
338 std::string name = options.name;
339 if (name.empty()) {
340 MS_LOG(ERROR) << "The graph name is null";
341 return Status::INVALID_ARGUMENT;
342 }
343
344 DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name);
345 if (wrap_ptr == nullptr) {
346 MS_LOG(ERROR) << "Get graph form DfGraphManager failed!";
347 return Status::NOT_FOUND;
348 }
349
350 if (wrap_ptr->graph_ptr_ == nullptr) {
351 MS_LOG(WARNING) << "The graph is null";
352 return Status::NOT_FOUND;
353 }
354
355 auto ms_context = MsContext::GetInstance();
356 MS_EXCEPTION_IF_NULL(ms_context);
357
358 MS_LOG(INFO) << "Run the graph in GE with " << inputs.size() << " inputs";
359 struct timeval start_time, end_time;
360 (void)gettimeofday(&start_time, nullptr);
361
362 if (ms_context->backend_policy() == "ge" || ms_context->backend_policy() == "ms") {
363 if (sess_ == nullptr) {
364 MS_LOG(ERROR) << "The GE session is null, can't run the graph!";
365 return Status::FAILED;
366 }
367 std::lock_guard<std::mutex> lock(wrap_ptr->mutex_);
368 wrap_ptr->times_++;
369 ge::Status ret = sess_->RunGraphWithStreamAsync(static_cast<uint32_t>(wrap_ptr->id_), stream, inputs, *outputs);
370 if (ret != ge::GRAPH_SUCCESS) {
371 MS_LOG(ERROR) << "Call GE RunGraphWithStreamAsync Failed, ret is: " << ret;
372 return Status::FAILED;
373 }
374 }
375
376 (void)gettimeofday(&end_time, nullptr);
377 const uint64_t kUSecondInSecond = 1000000;
378 uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
379 cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
380 MS_LOG(INFO) << "Call GE RunGraph Success in " << cost << " us, the GE outputs num is: " << outputs->size();
381
382 return Status::SUCCESS;
383 }
384
RegisterExternalAllocator(const void * const stream,GeAllocatorPtr allocator)385 Status GraphRunner::RegisterExternalAllocator(const void *const stream, GeAllocatorPtr allocator) {
386 if (sess_ == nullptr) {
387 MS_LOG(ERROR) << "The GE session is null, can't call GE RegisterExternalAllocator!";
388 return Status::FAILED;
389 }
390 ge::Status ret = sess_->RegisterExternalAllocator(stream, allocator);
391 if (ret != ge::GRAPH_SUCCESS) {
392 MS_LOG(ERROR) << "Call GE RegisterExternalAllocator Failed, ret is: " << ret;
393 return Status::FAILED;
394 }
395 is_allocator_registered = true;
396 return Status::SUCCESS;
397 }
398
UnregisterExternalAllocator(const void * const stream)399 Status GraphRunner::UnregisterExternalAllocator(const void *const stream) {
400 if (sess_ == nullptr) {
401 MS_LOG(ERROR) << "The GE session is null, can't call GE UnregisterExternalAllocator!";
402 return Status::FAILED;
403 }
404 ge::Status ret = sess_->UnregisterExternalAllocator(stream);
405 if (ret != ge::GRAPH_SUCCESS) {
406 MS_LOG(ERROR) << "Call GE UnregisterExternalAllocator Failed, ret is: " << ret;
407 return Status::FAILED;
408 }
409 is_allocator_registered = false;
410 return Status::SUCCESS;
411 }
412
CompileGraph(const RunOptions & options)413 Status GraphRunner::CompileGraph(const RunOptions &options) {
414 DfGraphWrapperPtr wrap_ptr = nullptr;
415 auto name = options.name;
416 auto ret = GetWrapper(name, &wrap_ptr);
417 if (ret != Status::SUCCESS) {
418 return ret;
419 }
420
421 MS_LOG(INFO) << "Start compile graph " << name;
422 ge::Status ge_ret = sess_->CompileGraph(static_cast<uint32_t>(wrap_ptr->id_));
423 if (ge_ret != ge::GRAPH_SUCCESS) {
424 MS_LOG(ERROR) << "Call GE CompileGraph Failed, ret is: " << ge_ret;
425 return Status::FAILED;
426 }
427
428 MS_LOG(INFO) << "Compile graph " << name << " success, start to get graph summary.";
429 return Status::SUCCESS;
430 }
431
CompileGraph(const RunOptions & options,::ge::CompiledGraphSummaryPtr * graph_summary)432 Status GraphRunner::CompileGraph(const RunOptions &options, ::ge::CompiledGraphSummaryPtr *graph_summary) {
433 MS_EXCEPTION_IF_NULL(graph_summary);
434 DfGraphWrapperPtr wrap_ptr = nullptr;
435 auto name = options.name;
436 auto ret = GetWrapper(name, &wrap_ptr);
437 if (ret != Status::SUCCESS) {
438 return ret;
439 }
440
441 MS_LOG(INFO) << "Start compile graph " << name;
442 ge::Status ge_ret = sess_->CompileGraph(static_cast<uint32_t>(wrap_ptr->id_));
443 if (ge_ret != ge::GRAPH_SUCCESS) {
444 MS_LOG(ERROR) << "Call GE CompileGraph Failed, ret is: " << ge_ret;
445 return Status::FAILED;
446 }
447
448 MS_LOG(INFO) << "Compile graph " << name << " success, start to get graph summary.";
449 *graph_summary = sess_->GetCompiledGraphSummary(static_cast<uint32_t>(wrap_ptr->id_));
450 MS_EXCEPTION_IF_NULL(*graph_summary);
451 MS_LOG(INFO) << "Get graph summary success for graph " << name;
452 return Status::SUCCESS;
453 }
454
SetConstMemory(const RunOptions & options,const void * const memory,size_t size)455 Status GraphRunner::SetConstMemory(const RunOptions &options, const void *const memory, size_t size) {
456 DfGraphWrapperPtr wrap_ptr = nullptr;
457 auto name = options.name;
458 auto ret = GetWrapper(name, &wrap_ptr);
459 if (ret != Status::SUCCESS) {
460 return ret;
461 }
462 ge::Status ge_ret = sess_->SetGraphConstMemoryBase(static_cast<uint32_t>(wrap_ptr->id_), memory, size);
463 if (ge_ret != ge::GRAPH_SUCCESS) {
464 MS_LOG(ERROR) << "Call GE SetGraphConstMemoryBase Failed, ret is: " << ge_ret;
465 return Status::FAILED;
466 }
467 return Status::SUCCESS;
468 }
469
UpdateFeatureMemory(const RunOptions & options,const void * const memory,size_t size)470 Status GraphRunner::UpdateFeatureMemory(const RunOptions &options, const void *const memory, size_t size) {
471 DfGraphWrapperPtr wrap_ptr = nullptr;
472 auto name = options.name;
473 auto ret = GetWrapper(name, &wrap_ptr);
474 if (ret != Status::SUCCESS) {
475 return ret;
476 }
477 ge::Status ge_ret = sess_->UpdateGraphFeatureMemoryBase(static_cast<uint32_t>(wrap_ptr->id_), memory, size);
478 if (ge_ret != ge::GRAPH_SUCCESS) {
479 MS_LOG(ERROR) << "Call GE SetGraphConstMemoryBase Failed, ret is: " << ge_ret;
480 return Status::FAILED;
481 }
482 return Status::SUCCESS;
483 }
484
GetWrapper(const std::string & name,DfGraphWrapperPtr * wrapper) const485 Status GraphRunner::GetWrapper(const std::string &name, DfGraphWrapperPtr *wrapper) const {
486 MS_EXCEPTION_IF_NULL(wrapper);
487 if (name.empty()) {
488 MS_LOG(ERROR) << "The graph name is null";
489 return Status::INVALID_ARGUMENT;
490 }
491
492 DfGraphWrapperPtr wrap_ptr = graph_manager_.GetGraphByName(name);
493 if (wrap_ptr == nullptr) {
494 MS_LOG(ERROR) << "Get graph form DfGraphManager failed!";
495 return Status::NOT_FOUND;
496 }
497
498 if (wrap_ptr->graph_ptr_ == nullptr) {
499 MS_LOG(WARNING) << "The graph is null";
500 return Status::NOT_FOUND;
501 }
502
503 if (sess_ == nullptr) {
504 MS_LOG(ERROR) << "The GE session is null, can't run the graph!";
505 return Status::FAILED;
506 }
507
508 *wrapper = wrap_ptr;
509 return Status::SUCCESS;
510 }
511 } // namespace transform
512 } // namespace mindspore
513