• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/cxx_api/context.h"
17 #include <string>
18 #include <memory>
19 #include "include/api/types.h"
20 #include "include/api/data_type.h"
21 #include "src/runtime/inner_allocator.h"
22 #include "src/common/log_adapter.h"
23 
24 namespace mindspore {
// Keys under which the device-info setters below store their values in
// DeviceInfoContext::Data::params.

// CPU / GPU / Kirin NPU options.
constexpr const char *kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr const char *kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
constexpr const char *kModelOptionGPUDeviceID = "mindspore.option.gpu.device_id";
constexpr const char *kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";

// Provider selection options shared by every DeviceInfoContext.
constexpr const char *kModelOptionProvider = "mindspore.option.provider";
constexpr const char *kModelOptionProviderDevice = "mindspore.option.provider.device";

// Ascend 310 options; the device id reuses the generic device-id key.
constexpr const char *kModelOptionDeviceID = "mindspore.option.device_id";
constexpr const char *kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr const char *kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
constexpr const char *kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
constexpr const char *kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
constexpr const char *kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
constexpr const char *kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type";
constexpr const char *kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
constexpr const char *kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
// NOTE(review): the capital 'K' prefix deviates from the kName convention used
// by every other key; kept as-is because the fusion-switch accessors use this name.
constexpr const char *KModelOptionAscend310FusionSwitchCfgPath =
  "mindspore.option.ascend310.fusion_switch_config_file_path";
constexpr const char *kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
constexpr const char *kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";
43 
Context()44 Context::Context() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
45 
46 template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
GetValue(const std::shared_ptr<DeviceInfoContext::Data> & data,const std::string & key)47 static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, const std::string &key) {
48   static U empty_result;
49   if (data == nullptr) {
50     return empty_result;
51   }
52   auto iter = data->params.find(key);
53   if (iter == data->params.end()) {
54     return empty_result;
55   }
56 #ifndef SUPPORT_NNIE
57   const std::any &value = iter->second;
58   return std::any_cast<const U &>(value);
59 #else
60   const std::experimental::any &value = iter->second;
61   return std::experimental::any_cast<const U &>(value);
62 #endif
63 }
64 
SetThreadNum(int32_t thread_num)65 void Context::SetThreadNum(int32_t thread_num) {
66   if (data_ == nullptr) {
67     MS_LOG(ERROR) << "Invalid context.";
68     return;
69   }
70   data_->thread_num = thread_num;
71 }
72 
GetThreadNum() const73 int32_t Context::GetThreadNum() const {
74   if (data_ == nullptr) {
75     MS_LOG(ERROR) << "Invalid context.";
76     return 0;
77   }
78   return data_->thread_num;
79 }
80 
SetEnableParallel(bool is_parallel)81 void Context::SetEnableParallel(bool is_parallel) {
82   if (data_ == nullptr) {
83     MS_LOG(ERROR) << "Invalid context.";
84     return;
85   }
86   data_->enable_parallel_ = is_parallel;
87 }
88 
GetEnableParallel() const89 bool Context::GetEnableParallel() const {
90   if (data_ == nullptr) {
91     MS_LOG(ERROR) << "Invalid context.";
92     return false;
93   }
94 
95   return data_->enable_parallel_;
96 }
97 
SetThreadAffinity(int mode)98 void Context::SetThreadAffinity(int mode) {
99   if (data_ == nullptr) {
100     MS_LOG(ERROR) << "Invalid context.";
101     return;
102   }
103   data_->affinity_mode_ = mode;
104   return;
105 }
106 
GetThreadAffinityMode() const107 int Context::GetThreadAffinityMode() const {
108   if (data_ == nullptr) {
109     MS_LOG(ERROR) << "Invalid context.";
110     return -1;
111   }
112   return data_->affinity_mode_;
113 }
114 
SetThreadAffinity(const std::vector<int> & core_list)115 void Context::SetThreadAffinity(const std::vector<int> &core_list) {
116   if (data_ == nullptr) {
117     MS_LOG(ERROR) << "Invalid context.";
118     return;
119   }
120   data_->affinity_core_list_ = core_list;
121 
122   return;
123 }
124 
GetThreadAffinityCoreList() const125 std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
126   if (data_ == nullptr) {
127     MS_LOG(ERROR) << "Invalid context.";
128     return {};
129   }
130   return data_->affinity_core_list_;
131 }
132 
SetDelegate(const std::shared_ptr<Delegate> & delegate)133 void Context::SetDelegate(const std::shared_ptr<Delegate> &delegate) {
134   if (data_ == nullptr) {
135     MS_LOG(ERROR) << "Invalid context.";
136     return;
137   }
138   data_->delegate = delegate;
139 }
140 
GetDelegate() const141 std::shared_ptr<Delegate> Context::GetDelegate() const {
142   if (data_ == nullptr) {
143     MS_LOG(ERROR) << "Invalid context.";
144     return nullptr;
145   }
146   return data_->delegate;
147 }
148 
MutableDeviceInfo()149 std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
150   static std::vector<std::shared_ptr<DeviceInfoContext>> empty;
151   if (data_ == nullptr) {
152     MS_LOG(ERROR) << "Invalid context.";
153     return empty;
154   }
155   return data_->device_info_list;
156 }
157 
DeviceInfoContext()158 DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared<Data>()) {}
159 
GetProvider() const160 std::string DeviceInfoContext::GetProvider() const {
161   if (data_ == nullptr) {
162     MS_LOG(ERROR) << "Invalid context.";
163     return "";
164   }
165   return GetValue<std::string>(data_, kModelOptionProvider);
166 }
167 
SetProvider(const std::string & provider)168 void DeviceInfoContext::SetProvider(const std::string &provider) {
169   if (data_ == nullptr) {
170     MS_LOG(ERROR) << "Invalid context.";
171     return;
172   }
173   data_->params[kModelOptionProvider] = provider;
174 }
175 
GetProviderDevice() const176 std::string DeviceInfoContext::GetProviderDevice() const {
177   if (data_ == nullptr) {
178     MS_LOG(ERROR) << "Invalid context.";
179     return "";
180   }
181   return GetValue<std::string>(data_, kModelOptionProviderDevice);
182 }
183 
SetProviderDevice(const std::string & device)184 void DeviceInfoContext::SetProviderDevice(const std::string &device) {
185   if (data_ == nullptr) {
186     MS_LOG(ERROR) << "Invalid context.";
187     return;
188   }
189   data_->params[kModelOptionProviderDevice] = device;
190 }
191 
SetAllocator(const std::shared_ptr<Allocator> & allocator)192 void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
193   if (data_ == nullptr) {
194     MS_LOG(ERROR) << "Invalid context.";
195     return;
196   }
197   data_->allocator = allocator;
198 }
199 
GetAllocator() const200 std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
201   if (data_ == nullptr) {
202     MS_LOG(ERROR) << "Invalid context.";
203     return nullptr;
204   }
205   return data_->allocator;
206 }
207 
SetEnableFP16(bool is_fp16)208 void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
209   if (data_ == nullptr) {
210     MS_LOG(ERROR) << "Invalid context.";
211     return;
212   }
213   data_->params[kModelOptionCpuEnableFP16] = is_fp16;
214 }
215 
GetEnableFP16() const216 bool CPUDeviceInfo::GetEnableFP16() const {
217   if (data_ == nullptr) {
218     MS_LOG(ERROR) << "Invalid context.";
219     return false;
220   }
221   return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
222 }
223 
SetEnableFP16(bool is_fp16)224 void GPUDeviceInfo::SetEnableFP16(bool is_fp16) {
225   if (data_ == nullptr) {
226     MS_LOG(ERROR) << "Invalid context.";
227     return;
228   }
229   data_->params[kModelOptionGPUEnableFP16] = is_fp16;
230 }
GetEnableFP16() const231 bool GPUDeviceInfo::GetEnableFP16() const {
232   if (data_ == nullptr) {
233     MS_LOG(ERROR) << "Invalid context.";
234     return false;
235   }
236   return GetValue<bool>(data_, kModelOptionGPUEnableFP16);
237 }
238 
SetFrequency(int frequency)239 void KirinNPUDeviceInfo::SetFrequency(int frequency) {
240   if (data_ == nullptr) {
241     MS_LOG(ERROR) << "Invalid context.";
242     return;
243   }
244   data_->params[kModelOptionKirinNpuFrequency] = frequency;
245 }
246 
GetFrequency() const247 int KirinNPUDeviceInfo::GetFrequency() const {
248   if (data_ == nullptr) {
249     MS_LOG(ERROR) << "Invalid context.";
250     return 0;
251   }
252   return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
253 }
254 
SetDeviceID(uint32_t device_id)255 void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
256   if (data_ == nullptr) {
257     MS_LOG(ERROR) << "Invalid context.";
258     return;
259   }
260   data_->params[kModelOptionGPUDeviceID] = device_id;
261 }
262 
GetDeviceID() const263 uint32_t GPUDeviceInfo::GetDeviceID() const {
264   if (data_ == nullptr) {
265     MS_LOG(ERROR) << "Invalid context.";
266     return 0;
267   }
268   return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
269 }
270 
SetPrecisionMode(const std::vector<char> & precision_mode)271 void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
272   MS_LOG(ERROR) << "Unsupported Feature.";
273 }
GetPrecisionModeChar() const274 std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
275   MS_LOG(ERROR) << "Unsupported Feature.";
276   std::vector<char> ret;
277   return ret;
278 }
279 
SetDeviceID(uint32_t device_id)280 void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
GetDeviceID() const281 uint32_t Ascend910DeviceInfo::GetDeviceID() const {
282   MS_LOG(ERROR) << "Unsupported Feature.";
283   return 0;
284 }
285 
SetDeviceID(uint32_t device_id)286 void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
287   if (data_ == nullptr) {
288     MS_LOG(ERROR) << "Invalid context.";
289     return;
290   }
291   data_->params[kModelOptionAscend310DeviceID] = device_id;
292 }
293 
GetDeviceID() const294 uint32_t Ascend310DeviceInfo::GetDeviceID() const {
295   if (data_ == nullptr) {
296     MS_LOG(ERROR) << "Invalid context.";
297     return 0;
298   }
299   return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
300 }
301 
SetInsertOpConfigPath(const std::vector<char> & cfg_path)302 void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
303   if (data_ == nullptr) {
304     MS_LOG(ERROR) << "Invalid context.";
305     return;
306   }
307   data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
308 }
GetInsertOpConfigPathChar() const309 std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
310   if (data_ == nullptr) {
311     MS_LOG(ERROR) << "Invalid context.";
312     return std::vector<char>();
313   }
314   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
315   return StringToChar(ref);
316 }
317 
SetInputFormat(const std::vector<char> & format)318 void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
319   if (data_ == nullptr) {
320     MS_LOG(ERROR) << "Invalid context.";
321     return;
322   }
323   data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
324 }
325 
GetInputFormatChar() const326 std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
327   if (data_ == nullptr) {
328     MS_LOG(ERROR) << "Invalid context.";
329     return std::vector<char>();
330   }
331   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
332   return StringToChar(ref);
333 }
334 
SetInputShape(const std::vector<char> & shape)335 void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
336   if (data_ == nullptr) {
337     MS_LOG(ERROR) << "Invalid context.";
338     return;
339   }
340   data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
341 }
GetInputShapeChar() const342 std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
343   if (data_ == nullptr) {
344     MS_LOG(ERROR) << "Invalid context.";
345     return std::vector<char>();
346   }
347   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
348   return StringToChar(ref);
349 }
350 
SetDynamicBatchSize(const std::vector<size_t> & dynamic_batch_size)351 void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
352   if (data_ == nullptr) {
353     MS_LOG(ERROR) << "Invalid context.";
354     return;
355   }
356   std::string batchs;
357   for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
358     if (i != 0) {
359       batchs.push_back(',');
360     }
361     batchs += std::to_string(dynamic_batch_size[i]);
362   }
363   data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
364 }
365 
GetDynamicBatchSizeChar() const366 std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
367   if (data_ == nullptr) {
368     MS_LOG(ERROR) << "Invalid context.";
369     return std::vector<char>();
370   }
371   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
372   return StringToChar(ref);
373 }
374 
SetPrecisionMode(const std::vector<char> & precision_mode)375 void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
376   if (data_ == nullptr) {
377     MS_LOG(ERROR) << "Invalid context.";
378     return;
379   }
380   data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
381 }
382 
GetPrecisionModeChar() const383 std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
384   if (data_ == nullptr) {
385     MS_LOG(ERROR) << "Invalid context.";
386     return std::vector<char>();
387   }
388   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
389   return StringToChar(ref);
390 }
391 
SetOpSelectImplMode(const std::vector<char> & op_select_impl_mode)392 void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
393   if (data_ == nullptr) {
394     MS_LOG(ERROR) << "Invalid context.";
395     return;
396   }
397   data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
398 }
399 
GetOpSelectImplModeChar() const400 std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
401   if (data_ == nullptr) {
402     MS_LOG(ERROR) << "Invalid context.";
403     return std::vector<char>();
404   }
405   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
406   return StringToChar(ref);
407 }
408 
SetFusionSwitchConfigPath(const std::vector<char> & cfg_path)409 void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
410   if (data_ == nullptr) {
411     MS_LOG(ERROR) << "Invalid context.";
412     return;
413   }
414   data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
415 }
GetFusionSwitchConfigPathChar() const416 std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
417   if (data_ == nullptr) {
418     MS_LOG(ERROR) << "Invalid context.";
419     return std::vector<char>();
420   }
421   const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
422   return StringToChar(ref);
423 }
424 
SetInputShapeMap(const std::map<int,std::vector<int>> & shape)425 void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
426   if (data_ == nullptr) {
427     MS_LOG(ERROR) << "Invalid context.";
428     return;
429   }
430   data_->params[kModelOptionAscend310InputShapeMap] = shape;
431 }
432 
GetInputShapeMap() const433 std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
434   if (data_ == nullptr) {
435     MS_LOG(ERROR) << "Invalid context.";
436     return std::map<int, std::vector<int>>();
437   }
438   return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
439 }
440 
SetOutputType(enum DataType output_type)441 void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
442   if (data_ == nullptr) {
443     MS_LOG(ERROR) << "Invalid context.";
444     return;
445   }
446   data_->params[kModelOptionAscend310OutputType] = output_type;
447 }
448 
GetOutputType() const449 enum DataType Ascend310DeviceInfo::GetOutputType() const {
450   if (data_ == nullptr) {
451     MS_LOG(ERROR) << "Invalid context.";
452     return DataType::kTypeUnknown;
453   }
454   return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
455 }
456 
SetBufferOptimizeMode(const std::vector<char> & buffer_optimize_mode)457 void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
458   if (data_ == nullptr) {
459     MS_LOG(ERROR) << "Invalid context.";
460     return;
461   }
462   data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
463 }
464 
GetBufferOptimizeModeChar() const465 std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
466   if (data_ == nullptr) {
467     MS_LOG(ERROR) << "Invalid context.";
468     return std::vector<char>();
469   }
470   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
471   return StringToChar(ref);
472 }
473 }  // namespace mindspore
474