• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "src/litert/cxx_api/context.h"

#include <memory>
#include <string>
#include <utility>

#include "include/api/types.h"
#include "include/api/data_type.h"
#include "include/lite_types.h"
#include "src/litert/inner_allocator.h"
#include "src/common/log_adapter.h"
#include "src/extendrt/delegate/tensorrt/distribution/distribution_base.h"
25 
26 namespace mindspore {
// Keys under which device options are stored in DeviceInfoContext::Data::params.
// The string values identify stored entries, so they must not change.

// Precision toggles.
constexpr const char *kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr const char *kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
constexpr const char *kModelOptionNPUEnableFP16 = "mindspore.option.npu.enable_fp16";

// GPU / OpenGL interop and distribution.
constexpr const char *kModelOptionGPUEnableGLTexture = "mindspore.option.gpu.enable_gl_texture_";
constexpr const char *kModelOptionGPUGLContext = "mindspore.option.gpu.gl_context_";
constexpr const char *kModelOptionGPUGLDisplay = "mindspore.option.gpu.gl_display_";
constexpr const char *kModelOptionGPUDeviceID = "mindspore.option.gpu.device_id";
constexpr const char *kModelOptionGPURankID = "mindspore.option.gpu.rank_id";
constexpr const char *kModelOptionGPUGroupSize = "mindspore.option.gpu.group_size";

// Kirin NPU.
constexpr const char *kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";

// Provider selection.
constexpr const char *kModelOptionProvider = "mindspore.option.provider";
constexpr const char *kModelOptionProviderDevice = "mindspore.option.provider.device";

// Device id (the Ascend key is an alias of the generic one).
constexpr const char *kModelOptionDeviceID = "mindspore.option.device_id";
constexpr const char *kModelOptionAscendDeviceID = kModelOptionDeviceID;

// Ascend-specific options.
constexpr const char *kModelOptionAscendInsertOpCfgPath = "mindspore.option.ascend.insert_op_config_file_path";
constexpr const char *kModelOptionAscendInputFormat = "mindspore.option.ascend.input_format";
constexpr const char *kModelOptionAscendInputShapeMap = "mindspore.option.ascend.input_shape_map";
constexpr const char *kModelOptionAscendInputShape = "mindspore.option.ascend.input_shape";
constexpr const char *kModelOptionAscendOutputType = "mindspore.option.ascend.output_type";
constexpr const char *kModelOptionAscendPrecisionMode = "mindspore.option.ascend.precision_mode";
constexpr const char *kModelOptionAscendOpSelectImplMode = "mindspore.option.ascend.op_select_impl_mode";
constexpr const char *KModelOptionAscendFusionSwitchCfgPath = "mindspore.option.ascend.fusion_switch_config_file_path";
constexpr const char *kModelOptionAscendDynamicBatchSize = "mindspore.option.ascend.dynamic_batch_size";
constexpr const char *kModelOptionAscendDynamicImageSize = "mindspore.option.ascend.dynamic_image_size";
constexpr const char *kModelOptionAscendBufferOptimize = "mindspore.option.ascend.buffer_optimize";
constexpr const char *kModelOptionAscendRankID = "mindspore.option.ascend.rank_id";

// NNRT-specific options.
constexpr const char *kModelOptionNNRTDeviceID = "mindspore.option.nnrt.device_id";
constexpr const char *kModelOptionNNRTPerformanceMode = "mindspore.option.nnrt.performance_mode";
constexpr const char *kModelOptionNNRTPriority = "mindspore.option.nnrt.priority";
constexpr const char *kModelOptionNNRTEnableFP16 = "mindspore.option.nnrt.enable_fp16";
constexpr const char *kModelOptionNNRTExtensions = "mindspore.option.nnrt.extensions";

#ifdef USE_GLOG
// C-linkage logger initializer provided by another translation unit; called from Context::Context().
extern "C" void mindspore_log_init();
#endif
63 
Context()64 Context::Context() : data_(std::make_shared<Data>()) {
65 #ifdef USE_GLOG
66   mindspore::mindspore_log_init();
67 #endif
68 }
69 
70 template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
GetValue(const std::shared_ptr<DeviceInfoContext::Data> & data,const std::string & key)71 static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, const std::string &key) {
72   static U empty_result;
73   if (data == nullptr) {
74     return empty_result;
75   }
76   auto iter = data->params.find(key);
77   if (iter == data->params.end()) {
78     return empty_result;
79   }
80 #ifndef SUPPORT_NNIE
81   const std::any &value = iter->second;
82   return std::any_cast<const U &>(value);
83 #else
84   const std::experimental::any &value = iter->second;
85   return std::experimental::any_cast<const U &>(value);
86 #endif
87 }
88 
SetThreadNum(int32_t thread_num)89 void Context::SetThreadNum(int32_t thread_num) {
90   if (data_ == nullptr) {
91     MS_LOG(ERROR) << "Invalid context.";
92     return;
93   }
94   data_->thread_num = thread_num;
95 }
96 
SetInterOpParallelNum(int32_t parallel_num)97 void Context::SetInterOpParallelNum(int32_t parallel_num) {
98   if (data_ == nullptr) {
99     MS_LOG(ERROR) << "Invalid context.";
100     return;
101   }
102   data_->inter_op_parallel_num_ = parallel_num;
103 }
104 
SetGroupInfoFile(std::string group_info_file)105 void Context::SetGroupInfoFile(std::string group_info_file) {
106   if (data_ == nullptr) {
107     MS_LOG(ERROR) << "Invalid context.";
108     return;
109   }
110   data_->group_info_file_ = group_info_file;
111 }
112 
GetGroupInfoFile() const113 std::string Context::GetGroupInfoFile() const {
114   if (data_ == nullptr) {
115     MS_LOG(ERROR) << "Invalid context.";
116     return "";
117   }
118   return data_->group_info_file_;
119 }
120 
GetInterOpParallelNum() const121 int32_t Context::GetInterOpParallelNum() const {
122   if (data_ == nullptr) {
123     MS_LOG(ERROR) << "Invalid context.";
124     return 0;
125   }
126   return data_->inter_op_parallel_num_;
127 }
128 
GetThreadNum() const129 int32_t Context::GetThreadNum() const {
130   if (data_ == nullptr) {
131     MS_LOG(ERROR) << "Invalid context.";
132     return 0;
133   }
134   return data_->thread_num;
135 }
136 
SetEnableParallel(bool is_parallel)137 void Context::SetEnableParallel(bool is_parallel) {
138   if (data_ == nullptr) {
139     MS_LOG(ERROR) << "Invalid context.";
140     return;
141   }
142   data_->enable_parallel_ = is_parallel;
143 }
144 
GetEnableParallel() const145 bool Context::GetEnableParallel() const {
146   if (data_ == nullptr) {
147     MS_LOG(ERROR) << "Invalid context.";
148     return false;
149   }
150   return data_->enable_parallel_;
151 }
152 
SetThreadAffinity(int mode)153 void Context::SetThreadAffinity(int mode) {
154   if (data_ == nullptr) {
155     MS_LOG(ERROR) << "Invalid context.";
156     return;
157   }
158   if (mode < lite::NO_BIND || mode > lite::MID_CPU) {
159     MS_LOG(WARNING) << "Invalid thread affinity mode: " << mode << ", change to NO_BIND mode.";
160     data_->affinity_mode_ = lite::NO_BIND;
161     return;
162   }
163   data_->affinity_mode_ = mode;
164   return;
165 }
166 
GetThreadAffinityMode() const167 int Context::GetThreadAffinityMode() const {
168   if (data_ == nullptr) {
169     MS_LOG(ERROR) << "Invalid context.";
170     return -1;
171   }
172   return data_->affinity_mode_;
173 }
174 
SetThreadAffinity(const std::vector<int> & core_list)175 void Context::SetThreadAffinity(const std::vector<int> &core_list) {
176   if (data_ == nullptr) {
177     MS_LOG(ERROR) << "Invalid context.";
178     return;
179   }
180   data_->affinity_core_list_ = core_list;
181 
182   return;
183 }
184 
GetThreadAffinityCoreList() const185 std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
186   if (data_ == nullptr) {
187     MS_LOG(ERROR) << "Invalid context.";
188     return {};
189   }
190   return data_->affinity_core_list_;
191 }
SetBuiltInDelegate(DelegateMode mode)192 void Context::SetBuiltInDelegate(DelegateMode mode) {
193   if (data_ == nullptr) {
194     MS_LOG(ERROR) << "Invalid context.";
195     return;
196   }
197   if (mode < kCoreML || mode > kNNAPI) {
198     MS_LOG(WARNING) << "Invalid built-in delegate mode: " << mode << ", do not enable any delegate.";
199     data_->delegate_mode_ = kNoDelegate;
200     return;
201   }
202   data_->delegate_mode_ = mode;
203   return;
204 }
205 
GetBuiltInDelegate() const206 DelegateMode Context::GetBuiltInDelegate() const {
207   if (data_ == nullptr) {
208     MS_LOG(ERROR) << "Invalid context.";
209     return kNoDelegate;
210   }
211   return data_->delegate_mode_;
212 }
213 
get_delegate() const214 std::shared_ptr<AbstractDelegate> Context::get_delegate() const {
215   if (data_ == nullptr) {
216     MS_LOG(ERROR) << "Invalid context.";
217     return nullptr;
218   }
219   return data_->delegate;
220 }
221 
222 // deprecated
SetDelegate(const std::shared_ptr<Delegate> & delegate)223 void Context::SetDelegate(const std::shared_ptr<Delegate> &delegate) {
224   if (data_ == nullptr) {
225     MS_LOG(ERROR) << "Invalid context.";
226     return;
227   }
228   data_->delegate = delegate;
229 }
230 
231 // deprecated
GetDelegate() const232 std::shared_ptr<Delegate> Context::GetDelegate() const {
233   if (data_ == nullptr) {
234     MS_LOG(ERROR) << "Invalid context.";
235     return nullptr;
236   }
237   return data_->delegate;
238 }
239 
SetMultiModalHW(bool float_mode)240 void Context::SetMultiModalHW(bool float_mode) {
241   if (data_ == nullptr) {
242     MS_LOG(ERROR) << "Invalid context.";
243     return;
244   }
245   data_->float_mode = float_mode;
246 }
247 
GetMultiModalHW() const248 bool Context::GetMultiModalHW() const {
249   if (data_ == nullptr) {
250     MS_LOG(ERROR) << "Invalid context.";
251     return false;
252   }
253   return data_->float_mode;
254 }
255 
MutableDeviceInfo()256 std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
257   static std::vector<std::shared_ptr<DeviceInfoContext>> empty{};
258   if (data_ == nullptr) {
259     MS_LOG(ERROR) << "Invalid context.";
260     return empty;
261   }
262   return data_->device_info_list;
263 }
264 
DeviceInfoContext()265 DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared<Data>()) {}
266 
GetProviderChar() const267 std::vector<char> DeviceInfoContext::GetProviderChar() const {
268   if (data_ == nullptr) {
269     MS_LOG(ERROR) << "Invalid context.";
270     return std::vector<char>();
271   }
272   const std::string &ref = GetValue<std::string>(data_, kModelOptionProvider);
273   return StringToChar(ref);
274 }
275 
SetProvider(const std::vector<char> & provider)276 void DeviceInfoContext::SetProvider(const std::vector<char> &provider) {
277   if (data_ == nullptr) {
278     MS_LOG(ERROR) << "Invalid context.";
279     return;
280   }
281   data_->params[kModelOptionProvider] = CharToString(provider);
282 }
283 
GetProviderDeviceChar() const284 std::vector<char> DeviceInfoContext::GetProviderDeviceChar() const {
285   if (data_ == nullptr) {
286     MS_LOG(ERROR) << "Invalid context.";
287     return std::vector<char>();
288   }
289   const std::string &ref = GetValue<std::string>(data_, kModelOptionProviderDevice);
290   return StringToChar(ref);
291 }
292 
SetProviderDevice(const std::vector<char> & device)293 void DeviceInfoContext::SetProviderDevice(const std::vector<char> &device) {
294   if (data_ == nullptr) {
295     MS_LOG(ERROR) << "Invalid context.";
296     return;
297   }
298   data_->params[kModelOptionProviderDevice] = CharToString(device);
299 }
300 
SetAllocator(const std::shared_ptr<Allocator> & allocator)301 void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
302   if (data_ == nullptr) {
303     MS_LOG(ERROR) << "Invalid context.";
304     return;
305   }
306   data_->allocator = allocator;
307 }
308 
GetAllocator() const309 std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
310   if (data_ == nullptr) {
311     MS_LOG(ERROR) << "Invalid context.";
312     return nullptr;
313   }
314   return data_->allocator;
315 }
316 
SetEnableFP16(bool is_fp16)317 void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
318   if (data_ == nullptr) {
319     MS_LOG(ERROR) << "Invalid context.";
320     return;
321   }
322   data_->params[kModelOptionCpuEnableFP16] = is_fp16;
323 }
324 
GetEnableFP16() const325 bool CPUDeviceInfo::GetEnableFP16() const {
326   if (data_ == nullptr) {
327     MS_LOG(ERROR) << "Invalid context.";
328     return false;
329   }
330   return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
331 }
332 
SetEnableFP16(bool is_fp16)333 void GPUDeviceInfo::SetEnableFP16(bool is_fp16) {
334   if (data_ == nullptr) {
335     MS_LOG(ERROR) << "Invalid context.";
336     return;
337   }
338   data_->params[kModelOptionGPUEnableFP16] = is_fp16;
339 }
340 
GetEnableFP16() const341 bool GPUDeviceInfo::GetEnableFP16() const {
342   if (data_ == nullptr) {
343     MS_LOG(ERROR) << "Invalid context.";
344     return false;
345   }
346   return GetValue<bool>(data_, kModelOptionGPUEnableFP16);
347 }
348 
SetEnableGLTexture(bool is_enable_gl_texture)349 void GPUDeviceInfo::SetEnableGLTexture(bool is_enable_gl_texture) {
350   if (data_ == nullptr) {
351     MS_LOG(ERROR) << "Invalid context.";
352     return;
353   }
354   data_->params[kModelOptionGPUEnableGLTexture] = is_enable_gl_texture;
355 }
356 
GetEnableGLTexture() const357 bool GPUDeviceInfo::GetEnableGLTexture() const {
358   if (data_ == nullptr) {
359     MS_LOG(ERROR) << "Invalid context.";
360     return false;
361   }
362   return GetValue<bool>(data_, kModelOptionGPUEnableGLTexture);
363 }
364 
SetGLContext(void * gl_context)365 void GPUDeviceInfo::SetGLContext(void *gl_context) {
366   if (data_ == nullptr) {
367     MS_LOG(ERROR) << "Invalid context.";
368     return;
369   }
370   data_->params[kModelOptionGPUGLContext] = gl_context;
371 }
372 
GetGLContext() const373 void *GPUDeviceInfo::GetGLContext() const {
374   if (data_ == nullptr) {
375     MS_LOG(ERROR) << "Invalid context.";
376     return nullptr;
377   }
378   return GetValue<void *>(data_, kModelOptionGPUGLContext);
379 }
380 
SetGLDisplay(void * gl_display)381 void GPUDeviceInfo::SetGLDisplay(void *gl_display) {
382   if (data_ == nullptr) {
383     MS_LOG(ERROR) << "Invalid context.";
384     return;
385   }
386   data_->params[kModelOptionGPUGLDisplay] = gl_display;
387 }
388 
GetGLDisplay() const389 void *GPUDeviceInfo::GetGLDisplay() const {
390   if (data_ == nullptr) {
391     MS_LOG(ERROR) << "Invalid context.";
392     return nullptr;
393   }
394   return GetValue<void *>(data_, kModelOptionGPUGLDisplay);
395 }
396 
SetEnableFP16(bool is_fp16)397 void KirinNPUDeviceInfo::SetEnableFP16(bool is_fp16) {
398   if (data_ == nullptr) {
399     MS_LOG(ERROR) << "Invalid context.";
400     return;
401   }
402   data_->params[kModelOptionNPUEnableFP16] = is_fp16;
403 }
404 
GetEnableFP16() const405 bool KirinNPUDeviceInfo::GetEnableFP16() const {
406   if (data_ == nullptr) {
407     MS_LOG(ERROR) << "Invalid context.";
408     return false;
409   }
410   return GetValue<bool>(data_, kModelOptionNPUEnableFP16);
411 }
412 
SetFrequency(int frequency)413 void KirinNPUDeviceInfo::SetFrequency(int frequency) {
414   if (data_ == nullptr) {
415     MS_LOG(ERROR) << "Invalid context.";
416     return;
417   }
418   data_->params[kModelOptionKirinNpuFrequency] = frequency;
419 }
420 
GetFrequency() const421 int KirinNPUDeviceInfo::GetFrequency() const {
422   if (data_ == nullptr) {
423     MS_LOG(ERROR) << "Invalid context.";
424     return 0;
425   }
426   return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
427 }
428 
SetDeviceID(uint32_t device_id)429 void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
430   if (data_ == nullptr) {
431     MS_LOG(ERROR) << "Invalid context.";
432     return;
433   }
434   data_->params[kModelOptionGPUDeviceID] = device_id;
435 }
436 
GetDeviceID() const437 uint32_t GPUDeviceInfo::GetDeviceID() const {
438   if (data_ == nullptr) {
439     MS_LOG(ERROR) << "Invalid context.";
440     return 0;
441   }
442   return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
443 }
444 
GetRankID() const445 int GPUDeviceInfo::GetRankID() const {
446 #ifdef SUPPORT_TENSORRT
447   data_->params[kModelOptionGPURankID] = lite::GetRankID();
448 #else
449   data_->params[kModelOptionGPURankID] = 0;
450 #endif
451   return GetValue<int>(data_, kModelOptionGPURankID);
452 }
453 
GetGroupSize() const454 int GPUDeviceInfo::GetGroupSize() const {
455 #ifdef SUPPORT_TENSORRT
456   data_->params[kModelOptionGPUGroupSize] = lite::GetGPUGroupSize();
457 #else
458   data_->params[kModelOptionGPUGroupSize] = 1;
459 #endif
460   return GetValue<int>(data_, kModelOptionGPUGroupSize);
461 }
462 
SetPrecisionMode(const std::vector<char> & precision_mode)463 void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
464   MS_LOG(ERROR) << "Unsupported Feature.";
465 }
GetPrecisionModeChar() const466 std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
467   MS_LOG(ERROR) << "Unsupported Feature.";
468   std::vector<char> ret;
469   return ret;
470 }
471 
SetDeviceID(uint32_t device_id)472 void AscendDeviceInfo::SetDeviceID(uint32_t device_id) {
473   if (data_ == nullptr) {
474     MS_LOG(ERROR) << "Invalid context.";
475     return;
476   }
477   data_->params[kModelOptionAscendDeviceID] = device_id;
478 }
479 
GetDeviceID() const480 uint32_t AscendDeviceInfo::GetDeviceID() const {
481   if (data_ == nullptr) {
482     MS_LOG(ERROR) << "Invalid context.";
483     return 0;
484   }
485   return GetValue<uint32_t>(data_, kModelOptionAscendDeviceID);
486 }
487 
SetRankID(uint32_t rank_id)488 void AscendDeviceInfo::SetRankID(uint32_t rank_id) {
489   if (data_ == nullptr) {
490     MS_LOG(ERROR) << "Invalid context.";
491     return;
492   }
493   data_->params[kModelOptionAscendRankID] = rank_id;
494 }
495 
GetRankID() const496 uint32_t AscendDeviceInfo::GetRankID() const {
497   if (data_ == nullptr) {
498     MS_LOG(ERROR) << "Invalid context.";
499     return 0;
500   }
501   return GetValue<uint32_t>(data_, kModelOptionAscendRankID);
502 }
503 
SetInsertOpConfigPath(const std::vector<char> & cfg_path)504 void AscendDeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
505   if (data_ == nullptr) {
506     MS_LOG(ERROR) << "Invalid context.";
507     return;
508   }
509   data_->params[kModelOptionAscendInsertOpCfgPath] = CharToString(cfg_path);
510 }
GetInsertOpConfigPathChar() const511 std::vector<char> AscendDeviceInfo::GetInsertOpConfigPathChar() const {
512   if (data_ == nullptr) {
513     MS_LOG(ERROR) << "Invalid context.";
514     return std::vector<char>();
515   }
516   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInsertOpCfgPath);
517   return StringToChar(ref);
518 }
519 
SetInputFormat(const std::vector<char> & format)520 void AscendDeviceInfo::SetInputFormat(const std::vector<char> &format) {
521   if (data_ == nullptr) {
522     MS_LOG(ERROR) << "Invalid context.";
523     return;
524   }
525   data_->params[kModelOptionAscendInputFormat] = CharToString(format);
526 }
527 
GetInputFormatChar() const528 std::vector<char> AscendDeviceInfo::GetInputFormatChar() const {
529   if (data_ == nullptr) {
530     MS_LOG(ERROR) << "Invalid context.";
531     return std::vector<char>();
532   }
533   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInputFormat);
534   return StringToChar(ref);
535 }
536 
SetInputShape(const std::vector<char> & shape)537 void AscendDeviceInfo::SetInputShape(const std::vector<char> &shape) {
538   if (data_ == nullptr) {
539     MS_LOG(ERROR) << "Invalid context.";
540     return;
541   }
542   data_->params[kModelOptionAscendInputShape] = CharToString(shape);
543 }
GetInputShapeChar() const544 std::vector<char> AscendDeviceInfo::GetInputShapeChar() const {
545   if (data_ == nullptr) {
546     MS_LOG(ERROR) << "Invalid context.";
547     return std::vector<char>();
548   }
549   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendInputShape);
550   return StringToChar(ref);
551 }
552 
SetDynamicBatchSize(const std::vector<size_t> & dynamic_batch_size)553 void AscendDeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
554   if (data_ == nullptr) {
555     MS_LOG(ERROR) << "Invalid context.";
556     return;
557   }
558   std::string batchs;
559   for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
560     if (i != 0) {
561       batchs.push_back(',');
562     }
563     batchs += std::to_string(dynamic_batch_size[i]);
564   }
565   data_->params[kModelOptionAscendDynamicBatchSize] = batchs;
566 }
567 
GetDynamicBatchSizeChar() const568 std::vector<char> AscendDeviceInfo::GetDynamicBatchSizeChar() const {
569   if (data_ == nullptr) {
570     MS_LOG(ERROR) << "Invalid context.";
571     return std::vector<char>();
572   }
573   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendDynamicBatchSize);
574   return StringToChar(ref);
575 }
576 
SetDynamicImageSize(const std::vector<char> & dynamic_image_size)577 void AscendDeviceInfo::SetDynamicImageSize(const std::vector<char> &dynamic_image_size) {
578   if (data_ == nullptr) {
579     MS_LOG(ERROR) << "Invalid context.";
580     return;
581   }
582   data_->params[kModelOptionAscendDynamicImageSize] = CharToString(dynamic_image_size);
583 }
584 
GetDynamicImageSizeChar() const585 std::vector<char> AscendDeviceInfo::GetDynamicImageSizeChar() const {
586   if (data_ == nullptr) {
587     MS_LOG(ERROR) << "Invalid context.";
588     return std::vector<char>();
589   }
590   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendDynamicImageSize);
591   return StringToChar(ref);
592 }
593 
SetPrecisionMode(const std::vector<char> & precision_mode)594 void AscendDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
595   if (data_ == nullptr) {
596     MS_LOG(ERROR) << "Invalid context.";
597     return;
598   }
599   data_->params[kModelOptionAscendPrecisionMode] = CharToString(precision_mode);
600 }
601 
GetPrecisionModeChar() const602 std::vector<char> AscendDeviceInfo::GetPrecisionModeChar() const {
603   if (data_ == nullptr) {
604     MS_LOG(ERROR) << "Invalid context.";
605     return std::vector<char>();
606   }
607   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendPrecisionMode);
608   return StringToChar(ref);
609 }
610 
SetOpSelectImplMode(const std::vector<char> & op_select_impl_mode)611 void AscendDeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
612   if (data_ == nullptr) {
613     MS_LOG(ERROR) << "Invalid context.";
614     return;
615   }
616   data_->params[kModelOptionAscendOpSelectImplMode] = CharToString(op_select_impl_mode);
617 }
618 
GetOpSelectImplModeChar() const619 std::vector<char> AscendDeviceInfo::GetOpSelectImplModeChar() const {
620   if (data_ == nullptr) {
621     MS_LOG(ERROR) << "Invalid context.";
622     return std::vector<char>();
623   }
624   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendOpSelectImplMode);
625   return StringToChar(ref);
626 }
627 
SetFusionSwitchConfigPath(const std::vector<char> & cfg_path)628 void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
629   if (data_ == nullptr) {
630     MS_LOG(ERROR) << "Invalid context.";
631     return;
632   }
633   data_->params[KModelOptionAscendFusionSwitchCfgPath] = CharToString(cfg_path);
634 }
GetFusionSwitchConfigPathChar() const635 std::vector<char> AscendDeviceInfo::GetFusionSwitchConfigPathChar() const {
636   if (data_ == nullptr) {
637     MS_LOG(ERROR) << "Invalid context.";
638     return std::vector<char>();
639   }
640   const std::string &ref = GetValue<std::string>(data_, KModelOptionAscendFusionSwitchCfgPath);
641   return StringToChar(ref);
642 }
643 
SetInputShapeMap(const std::map<int,std::vector<int>> & shape)644 void AscendDeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
645   if (data_ == nullptr) {
646     MS_LOG(ERROR) << "Invalid context.";
647     return;
648   }
649   data_->params[kModelOptionAscendInputShapeMap] = shape;
650 }
651 
GetInputShapeMap() const652 std::map<int, std::vector<int>> AscendDeviceInfo::GetInputShapeMap() const {
653   if (data_ == nullptr) {
654     MS_LOG(ERROR) << "Invalid context.";
655     return std::map<int, std::vector<int>>();
656   }
657   return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscendInputShapeMap);
658 }
659 
SetOutputType(enum DataType output_type)660 void AscendDeviceInfo::SetOutputType(enum DataType output_type) {
661   if (data_ == nullptr) {
662     MS_LOG(ERROR) << "Invalid context.";
663     return;
664   }
665   data_->params[kModelOptionAscendOutputType] = output_type;
666 }
667 
GetOutputType() const668 enum DataType AscendDeviceInfo::GetOutputType() const {
669   if (data_ == nullptr) {
670     MS_LOG(ERROR) << "Invalid context.";
671     return DataType::kTypeUnknown;
672   }
673   return GetValue<enum DataType>(data_, kModelOptionAscendOutputType);
674 }
675 
SetBufferOptimizeMode(const std::vector<char> & buffer_optimize_mode)676 void AscendDeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
677   if (data_ == nullptr) {
678     MS_LOG(ERROR) << "Invalid context.";
679     return;
680   }
681   data_->params[kModelOptionAscendBufferOptimize] = CharToString(buffer_optimize_mode);
682 }
683 
GetBufferOptimizeModeChar() const684 std::vector<char> AscendDeviceInfo::GetBufferOptimizeModeChar() const {
685   if (data_ == nullptr) {
686     MS_LOG(ERROR) << "Invalid context.";
687     return std::vector<char>();
688   }
689   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscendBufferOptimize);
690   return StringToChar(ref);
691 }
692 
SetDeviceID(size_t device_id)693 void NNRTDeviceInfo::SetDeviceID(size_t device_id) {
694   if (data_ == nullptr) {
695     MS_LOG(ERROR) << "Invalid context.";
696     return;
697   }
698   data_->params[kModelOptionNNRTDeviceID] = device_id;
699 }
700 
GetDeviceID() const701 size_t NNRTDeviceInfo::GetDeviceID() const {
702   if (data_ == nullptr) {
703     MS_LOG(ERROR) << "Invalid context.";
704     return 0;
705   }
706   return GetValue<size_t>(data_, kModelOptionNNRTDeviceID);
707 }
708 
SetPerformanceMode(int performance_mode)709 void NNRTDeviceInfo::SetPerformanceMode(int performance_mode) {
710   if (data_ == nullptr) {
711     MS_LOG(ERROR) << "Invalid context.";
712     return;
713   }
714   data_->params[kModelOptionNNRTPerformanceMode] = performance_mode;
715 }
716 
GetPerformanceMode() const717 int NNRTDeviceInfo::GetPerformanceMode() const {
718   if (data_ == nullptr) {
719     MS_LOG(ERROR) << "Invalid context.";
720     return 0;
721   }
722   return GetValue<int>(data_, kModelOptionNNRTPerformanceMode);
723 }
724 
SetPriority(int priority)725 void NNRTDeviceInfo::SetPriority(int priority) {
726   if (data_ == nullptr) {
727     MS_LOG(ERROR) << "Invalid context.";
728     return;
729   }
730   data_->params[kModelOptionNNRTPriority] = priority;
731 }
732 
GetPriority() const733 int NNRTDeviceInfo::GetPriority() const {
734   if (data_ == nullptr) {
735     MS_LOG(ERROR) << "Invalid context.";
736     return 0;
737   }
738   return GetValue<int>(data_, kModelOptionNNRTPriority);
739 }
740 
SetEnableFP16(bool is_fp16)741 void NNRTDeviceInfo::SetEnableFP16(bool is_fp16) {
742   if (data_ == nullptr) {
743     MS_LOG(ERROR) << "Invalid context.";
744     return;
745   }
746   data_->params[kModelOptionNNRTEnableFP16] = is_fp16;
747 }
748 
GetEnableFP16() const749 bool NNRTDeviceInfo::GetEnableFP16() const {
750   if (data_ == nullptr) {
751     MS_LOG(ERROR) << "Invalid context.";
752     return false;
753   }
754   return GetValue<bool>(data_, kModelOptionNNRTEnableFP16);
755 }
756 
SetExtensions(const std::vector<Extension> & extensions)757 void NNRTDeviceInfo::SetExtensions(const std::vector<Extension> &extensions) {
758   if (data_ == nullptr) {
759     MS_LOG(ERROR) << "Invalid context.";
760     return;
761   }
762   data_->params[kModelOptionNNRTExtensions] = extensions;
763 }
764 
GetExtensions() const765 std::vector<Extension> NNRTDeviceInfo::GetExtensions() const {
766   if (data_ == nullptr) {
767     MS_LOG(ERROR) << "Invalid context.";
768     return {};
769   }
770   return GetValue<std::vector<Extension>>(data_, kModelOptionNNRTExtensions);
771 }
772 }  // namespace mindspore
773