1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "src/litert/cxx_api/context.h"
#include <memory>
#include <string>
#include <utility>
#include "include/api/types.h"
#include "include/api/data_type.h"
#include "include/lite_types.h"
#include "src/litert/inner_allocator.h"
#include "src/common/log_adapter.h"
#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h"
25
26 namespace mindspore {
// Keys under which device-specific options are stored in DeviceInfoContext::Data::params.
// The key strings are part of the stored option state, so their text must stay stable.
// Half-precision switches for CPU, GPU and Kirin NPU.
constexpr auto kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr auto kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
constexpr auto kModelOptionNPUEnableFP16 = "mindspore.option.npu.enable_fp16";
// GPU OpenGL interop options.
// NOTE(review): these three key strings end with a trailing '_', unlike every other key here.
constexpr auto kModelOptionGPUEnableGLTexture = "mindspore.option.gpu.enable_gl_texture_";
constexpr auto kModelOptionGPUGLContext = "mindspore.option.gpu.gl_context_";
constexpr auto kModelOptionGPUGLDisplay = "mindspore.option.gpu.gl_display_";
// GPU device and distribution options.
constexpr auto kModelOptionGPUDeviceID = "mindspore.option.gpu.device_id";
constexpr auto kModelOptionGPURankID = "mindspore.option.gpu.rank_id";
constexpr auto kModelOptionGPUGroupSize = "mindspore.option.gpu.group_size";
// Kirin NPU options.
constexpr auto kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
// Provider options available on every DeviceInfoContext.
constexpr auto kModelOptionProvider = "mindspore.option.provider";
constexpr auto kModelOptionProviderDevice = "mindspore.option.provider.device";
// Generic device id key.
constexpr auto kModelOptionDeviceID = "mindspore.option.device_id";
// Ascend options; the Ascend device id reuses the generic device-id key.
constexpr auto kModelOptionAscendDeviceID = kModelOptionDeviceID;
constexpr auto kModelOptionAscendInsertOpCfgPath = "mindspore.option.ascend.insert_op_config_file_path";
constexpr auto kModelOptionAscendInputFormat = "mindspore.option.ascend.input_format";
constexpr auto kModelOptionAscendInputShapeMap = "mindspore.option.ascend.input_shape_map";
constexpr auto kModelOptionAscendInputShape = "mindspore.option.ascend.input_shape";
constexpr auto kModelOptionAscendOutputType = "mindspore.option.ascend.output_type";
constexpr auto kModelOptionAscendPrecisionMode = "mindspore.option.ascend.precision_mode";
constexpr auto kModelOptionAscendOpSelectImplMode = "mindspore.option.ascend.op_select_impl_mode";
// NOTE(review): the capital 'K' prefix is inconsistent with the other kModelOption* names.
constexpr auto KModelOptionAscendFusionSwitchCfgPath = "mindspore.option.ascend.fusion_switch_config_file_path";
constexpr auto kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dynamic_batch_size";
constexpr auto kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size";
constexpr auto kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize";
constexpr auto kModelOptionAscendRankID = "mindspore.option.ascend.rank_id";
// NNRT options.
constexpr auto kModelOptionNNRTDeviceID = "mindspore.option.nnrt.device_id";
constexpr auto kModelOptionNNRTPerformanceMode = "mindspore.option.nnrt.performance_mode";
constexpr auto kModelOptionNNRTPriority = "mindspore.option.nnrt.priority";
constexpr auto kModelOptionNNRTEnableFP16 = "mindspore.option.nnrt.enable_fp16";
constexpr auto kModelOptionNNRTExtensions = "mindspore.option.nnrt.extensions";
#ifdef USE_GLOG
// Logging initialisation hook with C linkage, invoked from Context::Context() when USE_GLOG is
// defined. It is only declared here; the definition is provided elsewhere (presumably the glog
// integration — TODO confirm).
extern "C" {
extern void mindspore_log_init();
}
#endif
63
Context()64 Context::Context() : data_(std::make_shared<Data>()) {
65 #ifdef USE_GLOG
66 mindspore::mindspore_log_init();
67 #endif
68 }
69
70 template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
GetValue(const std::shared_ptr<DeviceInfoContext::Data> & data,const std::string & key)71 static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, const std::string &key) {
72 static U empty_result;
73 if (data == nullptr) {
74 return empty_result;
75 }
76 auto iter = data->params.find(key);
77 if (iter == data->params.end()) {
78 return empty_result;
79 }
80 #ifndef SUPPORT_NNIE
81 const std::any &value = iter->second;
82 return std::any_cast<const U &>(value);
83 #else
84 const std::experimental::any &value = iter->second;
85 return std::experimental::any_cast<const U &>(value);
86 #endif
87 }
88
SetThreadNum(int32_t thread_num)89 void Context::SetThreadNum(int32_t thread_num) {
90 if (data_ == nullptr) {
91 MS_LOG(ERROR) << "Invalid context.";
92 return;
93 }
94 data_->thread_num = thread_num;
95 }
96
SetInterOpParallelNum(int32_t parallel_num)97 void Context::SetInterOpParallelNum(int32_t parallel_num) {
98 if (data_ == nullptr) {
99 MS_LOG(ERROR) << "Invalid context.";
100 return;
101 }
102 data_->inter_op_parallel_num_ = parallel_num;
103 }
104
SetGroupInfoFile(std::string group_info_file)105 void Context::SetGroupInfoFile(std::string group_info_file) {
106 if (data_ == nullptr) {
107 MS_LOG(ERROR) << "Invalid context.";
108 return;
109 }
110 data_->group_info_file_ = group_info_file;
111 }
112
GetGroupInfoFile() const113 std::string Context::GetGroupInfoFile() const {
114 if (data_ == nullptr) {
115 MS_LOG(ERROR) << "Invalid context.";
116 return "";
117 }
118 return data_->group_info_file_;
119 }
120
GetInterOpParallelNum() const121 int32_t Context::GetInterOpParallelNum() const {
122 if (data_ == nullptr) {
123 MS_LOG(ERROR) << "Invalid context.";
124 return 0;
125 }
126 return data_->inter_op_parallel_num_;
127 }
128
GetThreadNum() const129 int32_t Context::GetThreadNum() const {
130 if (data_ == nullptr) {
131 MS_LOG(ERROR) << "Invalid context.";
132 return 0;
133 }
134 return data_->thread_num;
135 }
136
SetEnableParallel(bool is_parallel)137 void Context::SetEnableParallel(bool is_parallel) {
138 if (data_ == nullptr) {
139 MS_LOG(ERROR) << "Invalid context.";
140 return;
141 }
142 data_->enable_parallel_ = is_parallel;
143 }
144
GetEnableParallel() const145 bool Context::GetEnableParallel() const {
146 if (data_ == nullptr) {
147 MS_LOG(ERROR) << "Invalid context.";
148 return false;
149 }
150 return data_->enable_parallel_;
151 }
152
SetThreadAffinity(int mode)153 void Context::SetThreadAffinity(int mode) {
154 if (data_ == nullptr) {
155 MS_LOG(ERROR) << "Invalid context.";
156 return;
157 }
158 if (mode < lite::NO_BIND || mode > lite::MID_CPU) {
159 MS_LOG(WARNING) << "Invalid thread affinity mode: " << mode << ", change to NO_BIND mode.";
160 data_->affinity_mode_ = lite::NO_BIND;
161 return;
162 }
163 data_->affinity_mode_ = mode;
164 return;
165 }
166
GetThreadAffinityMode() const167 int Context::GetThreadAffinityMode() const {
168 if (data_ == nullptr) {
169 MS_LOG(ERROR) << "Invalid context.";
170 return -1;
171 }
172 return data_->affinity_mode_;
173 }
174
SetThreadAffinity(const std::vector<int> & core_list)175 void Context::SetThreadAffinity(const std::vector<int> &core_list) {
176 if (data_ == nullptr) {
177 MS_LOG(ERROR) << "Invalid context.";
178 return;
179 }
180 data_->affinity_core_list_ = core_list;
181
182 return;
183 }
184
GetThreadAffinityCoreList() const185 std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
186 if (data_ == nullptr) {
187 MS_LOG(ERROR) << "Invalid context.";
188 return {};
189 }
190 return data_->affinity_core_list_;
191 }
SetBuiltInDelegate(DelegateMode mode)192 void Context::SetBuiltInDelegate(DelegateMode mode) {
193 if (data_ == nullptr) {
194 MS_LOG(ERROR) << "Invalid context.";
195 return;
196 }
197 if (mode < kCoreML || mode > kNNAPI) {
198 MS_LOG(WARNING) << "Invalid built-in delegate mode: " << mode << ", do not enable any delegate.";
199 data_->delegate_mode_ = kNoDelegate;
200 return;
201 }
202 data_->delegate_mode_ = mode;
203 return;
204 }
205
GetBuiltInDelegate() const206 DelegateMode Context::GetBuiltInDelegate() const {
207 if (data_ == nullptr) {
208 MS_LOG(ERROR) << "Invalid context.";
209 return kNoDelegate;
210 }
211 return data_->delegate_mode_;
212 }
213
get_delegate() const214 std::shared_ptr<AbstractDelegate> Context::get_delegate() const {
215 if (data_ == nullptr) {
216 MS_LOG(ERROR) << "Invalid context.";
217 return nullptr;
218 }
219 return data_->delegate;
220 }
221
222 // deprecated
SetDelegate(const std::shared_ptr<Delegate> & delegate)223 void Context::SetDelegate(const std::shared_ptr<Delegate> &delegate) {
224 if (data_ == nullptr) {
225 MS_LOG(ERROR) << "Invalid context.";
226 return;
227 }
228 data_->delegate = delegate;
229 }
230
231 // deprecated
GetDelegate() const232 std::shared_ptr<Delegate> Context::GetDelegate() const {
233 if (data_ == nullptr) {
234 MS_LOG(ERROR) << "Invalid context.";
235 return nullptr;
236 }
237 return data_->delegate;
238 }
239
SetMultiModalHW(bool float_mode)240 void Context::SetMultiModalHW(bool float_mode) {
241 if (data_ == nullptr) {
242 MS_LOG(ERROR) << "Invalid context.";
243 return;
244 }
245 data_->float_mode = float_mode;
246 }
247
GetMultiModalHW() const248 bool Context::GetMultiModalHW() const {
249 if (data_ == nullptr) {
250 MS_LOG(ERROR) << "Invalid context.";
251 return false;
252 }
253 return data_->float_mode;
254 }
255
MutableDeviceInfo()256 std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
257 static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
258 if (data_ == nullptr) {
259 MS_LOG(ERROR) << "Invalid context.";
260 return empty;
261 }
262 return data_->device_info_list;
263 }
264
DeviceInfoContext()265 DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared<Data>()) {}
266
GetProviderChar() const267 std::vector<char> DeviceInfoContext::GetProviderChar() const {
268 if (data_ == nullptr) {
269 MS_LOG(ERROR) << "Invalid context.";
270 return std::vector<char>();
271 }
272 const std::string &ref = GetValue<std::string>(data_, kModelOptionProvider);
273 return StringToChar(ref);
274 }
275
SetProvider(const std::vector<char> & provider)276 void DeviceInfoContext::SetProvider(const std::vector<char> &provider) {
277 if (data_ == nullptr) {
278 MS_LOG(ERROR) << "Invalid context.";
279 return;
280 }
281 data_->params[kModelOptionProvider] = CharToString(provider);
282 }
283
GetProviderDeviceChar() const284 std::vector<char> DeviceInfoContext::GetProviderDeviceChar() const {
285 if (data_ == nullptr) {
286 MS_LOG(ERROR) << "Invalid context.";
287 return std::vector<char>();
288 }
289 const std::string &ref = GetValue<std::string>(data_, kModelOptionProviderDevice);
290 return StringToChar(ref);
291 }
292
SetProviderDevice(const std::vector<char> & device)293 void DeviceInfoContext::SetProviderDevice(const std::vector<char> &device) {
294 if (data_ == nullptr) {
295 MS_LOG(ERROR) << "Invalid context.";
296 return;
297 }
298 data_->params[kModelOptionProviderDevice] = CharToString(device);
299 }
300
SetAllocator(const std::shared_ptr<Allocator> & allocator)301 void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
302 if (data_ == nullptr) {
303 MS_LOG(ERROR) << "Invalid context.";
304 return;
305 }
306 data_->allocator = allocator;
307 }
308
GetAllocator() const309 std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
310 if (data_ == nullptr) {
311 MS_LOG(ERROR) << "Invalid context.";
312 return nullptr;
313 }
314 return data_->allocator;
315 }
316
SetEnableFP16(bool is_fp16)317 void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
318 if (data_ == nullptr) {
319 MS_LOG(ERROR) << "Invalid context.";
320 return;
321 }
322 data_->params[kModelOptionCpuEnableFP16] = is_fp16;
323 }
324
GetEnableFP16() const325 bool CPUDeviceInfo::GetEnableFP16() const {
326 if (data_ == nullptr) {
327 MS_LOG(ERROR) << "Invalid context.";
328 return false;
329 }
330 return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
331 }
332
SetEnableFP16(bool is_fp16)333 void GPUDeviceInfo::SetEnableFP16(bool is_fp16) {
334 if (data_ == nullptr) {
335 MS_LOG(ERROR) << "Invalid context.";
336 return;
337 }
338 data_->params[kModelOptionGPUEnableFP16] = is_fp16;
339 }
340
GetEnableFP16() const341 bool GPUDeviceInfo::GetEnableFP16() const {
342 if (data_ == nullptr) {
343 MS_LOG(ERROR) << "Invalid context.";
344 return false;
345 }
346 return GetValue<bool>(data_, kModelOptionGPUEnableFP16);
347 }
348
SetEnableGLTexture(bool is_enable_gl_texture)349 void GPUDeviceInfo::SetEnableGLTexture(bool is_enable_gl_texture) {
350 if (data_ == nullptr) {
351 MS_LOG(ERROR) << "Invalid context.";
352 return;
353 }
354 data_->params[kModelOptionGPUEnableGLTexture] = is_enable_gl_texture;
355 }
356
GetEnableGLTexture() const357 bool GPUDeviceInfo::GetEnableGLTexture() const {
358 if (data_ == nullptr) {
359 MS_LOG(ERROR) << "Invalid context.";
360 return false;
361 }
362 return GetValue<bool>(data_, kModelOptionGPUEnableGLTexture);
363 }
364
SetGLContext(void * gl_context)365 void GPUDeviceInfo::SetGLContext(void *gl_context) {
366 if (data_ == nullptr) {
367 MS_LOG(ERROR) << "Invalid context.";
368 return;
369 }
370 data_->params[kModelOptionGPUGLContext] = gl_context;
371 }
372
GetGLContext() const373 void *GPUDeviceInfo::GetGLContext() const {
374 if (data_ == nullptr) {
375 MS_LOG(ERROR) << "Invalid context.";
376 return nullptr;
377 }
378 return GetValue<void *>(data_, kModelOptionGPUGLContext);
379 }
380
SetGLDisplay(void * gl_display)381 void GPUDeviceInfo::SetGLDisplay(void *gl_display) {
382 if (data_ == nullptr) {
383 MS_LOG(ERROR) << "Invalid context.";
384 return;
385 }
386 data_->params[kModelOptionGPUGLDisplay] = gl_display;
387 }
388
GetGLDisplay() const389 void *GPUDeviceInfo::GetGLDisplay() const {
390 if (data_ == nullptr) {
391 MS_LOG(ERROR) << "Invalid context.";
392 return nullptr;
393 }
394 return GetValue<void *>(data_, kModelOptionGPUGLDisplay);
395 }
396
SetEnableFP16(bool is_fp16)397 void KirinNPUDeviceInfo::SetEnableFP16(bool is_fp16) {
398 if (data_ == nullptr) {
399 MS_LOG(ERROR) << "Invalid context.";
400 return;
401 }
402 data_->params[kModelOptionNPUEnableFP16] = is_fp16;
403 }
404
GetEnableFP16() const405 bool KirinNPUDeviceInfo::GetEnableFP16() const {
406 if (data_ == nullptr) {
407 MS_LOG(ERROR) << "Invalid context.";
408 return false;
409 }
410 return GetValue<bool>(data_, kModelOptionNPUEnableFP16);
411 }
412
SetFrequency(int frequency)413 void KirinNPUDeviceInfo::SetFrequency(int frequency) {
414 if (data_ == nullptr) {
415 MS_LOG(ERROR) << "Invalid context.";
416 return;
417 }
418 data_->params[kModelOptionKirinNpuFrequency] = frequency;
419 }
420
GetFrequency() const421 int KirinNPUDeviceInfo::GetFrequency() const {
422 if (data_ == nullptr) {
423 MS_LOG(ERROR) << "Invalid context.";
424 return 0;
425 }
426 return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
427 }
428
SetDeviceID(uint32_t device_id)429 void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
430 if (data_ == nullptr) {
431 MS_LOG(ERROR) << "Invalid context.";
432 return;
433 }
434 data_->params[kModelOptionGPUDeviceID] = device_id;
435 }
436
GetDeviceID() const437 uint32_t GPUDeviceInfo::GetDeviceID() const {
438 if (data_ == nullptr) {
439 MS_LOG(ERROR) << "Invalid context.";
440 return 0;
441 }
442 return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
443 }
444
GetRankID() const445 int GPUDeviceInfo::GetRankID() const {
446 #ifdef SUPPORT_TENSORRT
447 data_->params[kModelOptionGPURankID] = lite::GetRankID();
448 #else
449 data_->params[kModelOptionGPURankID] = 0;
450 #endif
451 return GetValue<int>(data_, kModelOptionGPURankID);
452 }
453
GetGroupSize() const454 int GPUDeviceInfo::GetGroupSize() const {
455 #ifdef SUPPORT_TENSORRT
456 data_->params[kModelOptionGPUGroupSize] = lite::GetGPUGroupSize();
457 #else
458 data_->params[kModelOptionGPUGroupSize] = 1;
459 #endif
460 return GetValue<int>(data_, kModelOptionGPUGroupSize);
461 }
462
SetPrecisionMode(const std::vector<char> & precision_mode)463 void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
464 MS_LOG(ERROR) << "Unsupported Feature.";
465 }
GetPrecisionModeChar() const466 std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
467 MS_LOG(ERROR) << "Unsupported Feature.";
468 std::vector<char> ret;
469 return ret;
470 }
471
SetDeviceID(uint32_t device_id)472 void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
473 if (data_ == nullptr) {
474 MS_LOG(ERROR) << "Invalid context.";
475 return;
476 }
477 data_->params[kModelOptionAscendDeviceID] = device_id;
478 }
479
GetDeviceID() const480 uint32_t AscendDeviceInfo::GetDeviceID() const {
481 if (data_ == nullptr) {
482 MS_LOG(ERROR) << "Invalid context.";
483 return 0;
484 }
485 return GetValue<uint32_t>(data_, kModelOptionAscendDeviceID);
486 }
487
SetRankID(uint32_t rank_id)488 void AscendDeviceInfo::SetRankID(uint32_t rank_id) {
489 if (data_ == nullptr) {
490 MS_LOG(ERROR) << "Invalid context.";
491 return;
492 }
493 data_->params[kModelOptionAscendRankID] = rank_id;
494 }
495
GetRankID() const496 uint32_t AscendDeviceInfo::GetRankID() const {
497 if (data_ == nullptr) {
498 MS_LOG(ERROR) << "Invalid context.";
499 return 0;
500 }
501 return GetValue<uint32_t>(data_, kModelOptionAscendRankID);
502 }
503
SetInsertOpConfigPath(const std::vector<char> & cfg_path)504 void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
505 if (data_ == nullptr) {
506 MS_LOG(ERROR) << "Invalid context.";
507 return;
508 }
509 data_->params[kModelOptionAscendInsertOpCfgPath] = CharToString(cfg_path);
510 }
GetInsertOpConfigPathChar() const511 std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
512 if (data_ == nullptr) {
513 MS_LOG(ERROR) << "Invalid context.";
514 return std::vector<char>();
515 }
516 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInsertOpCfgPath);
517 return StringToChar(ref);
518 }
519
SetInputFormat(const std::vector<char> & format)520 void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
521 if (data_ == nullptr) {
522 MS_LOG(ERROR) << "Invalid context.";
523 return;
524 }
525 data_->params[kModelOptionAscendInputFormat] = CharToString(format);
526 }
527
GetInputFormatChar() const528 std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
529 if (data_ == nullptr) {
530 MS_LOG(ERROR) << "Invalid context.";
531 return std::vector<char>();
532 }
533 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInputFormat);
534 return StringToChar(ref);
535 }
536
SetInputShape(const std::vector<char> & shape)537 void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
538 if (data_ == nullptr) {
539 MS_LOG(ERROR) << "Invalid context.";
540 return;
541 }
542 data_->params[kModelOptionAscendInputShape] = CharToString(shape);
543 }
GetInputShapeChar() const544 std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
545 if (data_ == nullptr) {
546 MS_LOG(ERROR) << "Invalid context.";
547 return std::vector<char>();
548 }
549 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInputShape);
550 return StringToChar(ref);
551 }
552
SetDynamicBatchSize(const std::vector<size_t> & dynamic_batch_size)553 void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
554 if (data_ == nullptr) {
555 MS_LOG(ERROR) << "Invalid context.";
556 return;
557 }
558 std::string batchs;
559 for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
560 if (i != 0) {
561 batchs.push_back(',');
562 }
563 batchs += std::to_string(dynamic_batch_size[i]);
564 }
565 data_->params[kModelOptionAscendDynamicBatchSize] = batchs;
566 }
567
GetDynamicBatchSizeChar() const568 std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
569 if (data_ == nullptr) {
570 MS_LOG(ERROR) << "Invalid context.";
571 return std::vector<char>();
572 }
573 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendDynamicBatchSize);
574 return StringToChar(ref);
575 }
576
SetDynamicImageSize(const std::vector<char> & dynamic_image_size)577 void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
578 if (data_ == nullptr) {
579 MS_LOG(ERROR) << "Invalid context.";
580 return;
581 }
582 data_->params[kModelOptionAscendDynamicImageSize] = CharToString(dynamic_image_size);
583 }
584
GetDynamicImageSizeChar() const585 std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
586 if (data_ == nullptr) {
587 MS_LOG(ERROR) << "Invalid context.";
588 return std::vector<char>();
589 }
590 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendDynamicImageSize);
591 return StringToChar(ref);
592 }
593
SetPrecisionMode(const std::vector<char> & precision_mode)594 void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
595 if (data_ == nullptr) {
596 MS_LOG(ERROR) << "Invalid context.";
597 return;
598 }
599 data_->params[kModelOptionAscendPrecisionMode] = CharToString(precision_mode);
600 }
601
GetPrecisionModeChar() const602 std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
603 if (data_ == nullptr) {
604 MS_LOG(ERROR) << "Invalid context.";
605 return std::vector<char>();
606 }
607 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendPrecisionMode);
608 return StringToChar(ref);
609 }
610
SetOpSelectImplMode(const std::vector<char> & op_select_impl_mode)611 void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
612 if (data_ == nullptr) {
613 MS_LOG(ERROR) << "Invalid context.";
614 return;
615 }
616 data_->params[kModelOptionAscendOpSelectImplMode] = CharToString(op_select_impl_mode);
617 }
618
GetOpSelectImplModeChar() const619 std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
620 if (data_ == nullptr) {
621 MS_LOG(ERROR) << "Invalid context.";
622 return std::vector<char>();
623 }
624 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendOpSelectImplMode);
625 return StringToChar(ref);
626 }
627
SetFusionSwitchConfigPath(const std::vector<char> & cfg_path)628 void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
629 if (data_ == nullptr) {
630 MS_LOG(ERROR) << "Invalid context.";
631 return;
632 }
633 data_->params[KModelOptionAscendFusionSwitchCfgPath] = CharToString(cfg_path);
634 }
GetFusionSwitchConfigPathChar() const635 std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
636 if (data_ == nullptr) {
637 MS_LOG(ERROR) << "Invalid context.";
638 return std::vector<char>();
639 }
640 const std::string &ref = GetValue<std::string>(data_, KModelOptionAscendFusionSwitchCfgPath);
641 return StringToChar(ref);
642 }
643
SetInputShapeMap(const std::map<int,std::vector<int>> & shape)644 void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
645 if (data_ == nullptr) {
646 MS_LOG(ERROR) << "Invalid context.";
647 return;
648 }
649 data_->params[kModelOptionAscendInputShapeMap] = shape;
650 }
651
GetInputShapeMap() const652 std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
653 if (data_ == nullptr) {
654 MS_LOG(ERROR) << "Invalid context.";
655 return std::map<int, std::vector<int>>();
656 }
657 return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscendInputShapeMap);
658 }
659
SetOutputType(enum DataType output_type)660 void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
661 if (data_ == nullptr) {
662 MS_LOG(ERROR) << "Invalid context.";
663 return;
664 }
665 data_->params[kModelOptionAscendOutputType] = output_type;
666 }
667
GetOutputType() const668 enum DataType AscendDeviceInfo::GetOutputType() const {
669 if (data_ == nullptr) {
670 MS_LOG(ERROR) << "Invalid context.";
671 return DataType::kTypeUnknown;
672 }
673 return GetValue<enum DataType>(data_, kModelOptionAscendOutputType);
674 }
675
SetBufferOptimizeMode(const std::vector<char> & buffer_optimize_mode)676 void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
677 if (data_ == nullptr) {
678 MS_LOG(ERROR) << "Invalid context.";
679 return;
680 }
681 data_->params[kModelOptionAscendBufferOptimize] = CharToString(buffer_optimize_mode);
682 }
683
GetBufferOptimizeModeChar() const684 std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
685 if (data_ == nullptr) {
686 MS_LOG(ERROR) << "Invalid context.";
687 return std::vector<char>();
688 }
689 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendBufferOptimize);
690 return StringToChar(ref);
691 }
692
SetDeviceID(size_t device_id)693 void NNRTDeviceInfo::SetDeviceID(size_t device_id) {
694 if (data_ == nullptr) {
695 MS_LOG(ERROR) << "Invalid context.";
696 return;
697 }
698 data_->params[kModelOptionNNRTDeviceID] = device_id;
699 }
700
GetDeviceID() const701 size_t NNRTDeviceInfo::GetDeviceID() const {
702 if (data_ == nullptr) {
703 MS_LOG(ERROR) << "Invalid context.";
704 return 0;
705 }
706 return GetValue<size_t>(data_, kModelOptionNNRTDeviceID);
707 }
708
SetPerformanceMode(int performance_mode)709 void NNRTDeviceInfo::SetPerformanceMode(int performance_mode) {
710 if (data_ == nullptr) {
711 MS_LOG(ERROR) << "Invalid context.";
712 return;
713 }
714 data_->params[kModelOptionNNRTPerformanceMode] = performance_mode;
715 }
716
GetPerformanceMode() const717 int NNRTDeviceInfo::GetPerformanceMode() const {
718 if (data_ == nullptr) {
719 MS_LOG(ERROR) << "Invalid context.";
720 return 0;
721 }
722 return GetValue<int>(data_, kModelOptionNNRTPerformanceMode);
723 }
724
SetPriority(int priority)725 void NNRTDeviceInfo::SetPriority(int priority) {
726 if (data_ == nullptr) {
727 MS_LOG(ERROR) << "Invalid context.";
728 return;
729 }
730 data_->params[kModelOptionNNRTPriority] = priority;
731 }
732
GetPriority() const733 int NNRTDeviceInfo::GetPriority() const {
734 if (data_ == nullptr) {
735 MS_LOG(ERROR) << "Invalid context.";
736 return 0;
737 }
738 return GetValue<int>(data_, kModelOptionNNRTPriority);
739 }
740
SetEnableFP16(bool is_fp16)741 void NNRTDeviceInfo::SetEnableFP16(bool is_fp16) {
742 if (data_ == nullptr) {
743 MS_LOG(ERROR) << "Invalid context.";
744 return;
745 }
746 data_->params[kModelOptionNNRTEnableFP16] = is_fp16;
747 }
748
GetEnableFP16() const749 bool NNRTDeviceInfo::GetEnableFP16() const {
750 if (data_ == nullptr) {
751 MS_LOG(ERROR) << "Invalid context.";
752 return false;
753 }
754 return GetValue<bool>(data_, kModelOptionNNRTEnableFP16);
755 }
756
SetExtensions(const std::vector<Extension> & extensions)757 void NNRTDeviceInfo::SetExtensions(const std::vector<Extension> &extensions) {
758 if (data_ == nullptr) {
759 MS_LOG(ERROR) << "Invalid context.";
760 return;
761 }
762 data_->params[kModelOptionNNRTExtensions] = extensions;
763 }
764
GetExtensions() const765 std::vector<Extension> NNRTDeviceInfo::GetExtensions() const {
766 if (data_ == nullptr) {
767 MS_LOG(ERROR) << "Invalid context.";
768 return {};
769 }
770 return GetValue<std::vector<Extension>>(data_, kModelOptionNNRTExtensions);
771 }
772 } // namespace mindspore
773