1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/cxx_api/context.h"
17 #include <string>
18 #include <memory>
19 #include "include/api/types.h"
20 #include "include/api/data_type.h"
21 #include "src/runtime/inner_allocator.h"
22 #include "src/common/log_adapter.h"
23
24 namespace mindspore {
// Keys under which the typed setters below store their values in DeviceInfoContext::Data::params.
// Each key is spelled out as `constexpr const char *`, the same type `constexpr auto` deduces for a literal.

// Generic / per-backend keys.
constexpr const char *kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr const char *kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
constexpr const char *kModelOptionGPUDeviceID = "mindspore.option.gpu.device_id";
constexpr const char *kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr const char *kModelOptionProvider = "mindspore.option.provider";
constexpr const char *kModelOptionProviderDevice = "mindspore.option.provider.device";
constexpr const char *kModelOptionDeviceID = "mindspore.option.device_id";

// Ascend310 keys. The device id deliberately reuses the generic device-id key.
constexpr const char *kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr const char *kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
constexpr const char *kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
constexpr const char *kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
constexpr const char *kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
constexpr const char *kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type";
constexpr const char *kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
constexpr const char *kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
// NOTE(review): capital-K prefix is inconsistent with the other keys; kept because the
// fusion-switch accessors reference this exact name.
constexpr const char *KModelOptionAscend310FusionSwitchCfgPath =
  "mindspore.option.ascend310.fusion_switch_config_file_path";
constexpr const char *kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
constexpr const char *kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";
43
Context()44 Context::Context() : data_(std::shared_ptr<Data>(new (std::nothrow) Data())) {}
45
46 template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
GetValue(const std::shared_ptr<DeviceInfoContext::Data> & data,const std::string & key)47 static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, const std::string &key) {
48 static U empty_result;
49 if (data == nullptr) {
50 return empty_result;
51 }
52 auto iter = data->params.find(key);
53 if (iter == data->params.end()) {
54 return empty_result;
55 }
56 #ifndef SUPPORT_NNIE
57 const std::any &value = iter->second;
58 return std::any_cast<const U &>(value);
59 #else
60 const std::experimental::any &value = iter->second;
61 return std::experimental::any_cast<const U &>(value);
62 #endif
63 }
64
SetThreadNum(int32_t thread_num)65 void Context::SetThreadNum(int32_t thread_num) {
66 if (data_ == nullptr) {
67 MS_LOG(ERROR) << "Invalid context.";
68 return;
69 }
70 data_->thread_num = thread_num;
71 }
72
GetThreadNum() const73 int32_t Context::GetThreadNum() const {
74 if (data_ == nullptr) {
75 MS_LOG(ERROR) << "Invalid context.";
76 return 0;
77 }
78 return data_->thread_num;
79 }
80
SetEnableParallel(bool is_parallel)81 void Context::SetEnableParallel(bool is_parallel) {
82 if (data_ == nullptr) {
83 MS_LOG(ERROR) << "Invalid context.";
84 return;
85 }
86 data_->enable_parallel_ = is_parallel;
87 }
88
GetEnableParallel() const89 bool Context::GetEnableParallel() const {
90 if (data_ == nullptr) {
91 MS_LOG(ERROR) << "Invalid context.";
92 return false;
93 }
94
95 return data_->enable_parallel_;
96 }
97
SetThreadAffinity(int mode)98 void Context::SetThreadAffinity(int mode) {
99 if (data_ == nullptr) {
100 MS_LOG(ERROR) << "Invalid context.";
101 return;
102 }
103 data_->affinity_mode_ = mode;
104 return;
105 }
106
GetThreadAffinityMode() const107 int Context::GetThreadAffinityMode() const {
108 if (data_ == nullptr) {
109 MS_LOG(ERROR) << "Invalid context.";
110 return -1;
111 }
112 return data_->affinity_mode_;
113 }
114
SetThreadAffinity(const std::vector<int> & core_list)115 void Context::SetThreadAffinity(const std::vector<int> &core_list) {
116 if (data_ == nullptr) {
117 MS_LOG(ERROR) << "Invalid context.";
118 return;
119 }
120 data_->affinity_core_list_ = core_list;
121
122 return;
123 }
124
GetThreadAffinityCoreList() const125 std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
126 if (data_ == nullptr) {
127 MS_LOG(ERROR) << "Invalid context.";
128 return {};
129 }
130 return data_->affinity_core_list_;
131 }
132
SetDelegate(const std::shared_ptr<Delegate> & delegate)133 void Context::SetDelegate(const std::shared_ptr<Delegate> &delegate) {
134 if (data_ == nullptr) {
135 MS_LOG(ERROR) << "Invalid context.";
136 return;
137 }
138 data_->delegate = delegate;
139 }
140
GetDelegate() const141 std::shared_ptr<Delegate> Context::GetDelegate() const {
142 if (data_ == nullptr) {
143 MS_LOG(ERROR) << "Invalid context.";
144 return nullptr;
145 }
146 return data_->delegate;
147 }
148
MutableDeviceInfo()149 std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
150 static std::vector<std::shared_ptr<DeviceInfoContext>> empty;
151 if (data_ == nullptr) {
152 MS_LOG(ERROR) << "Invalid context.";
153 return empty;
154 }
155 return data_->device_info_list;
156 }
157
DeviceInfoContext()158 DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared<Data>()) {}
159
GetProvider() const160 std::string DeviceInfoContext::GetProvider() const {
161 if (data_ == nullptr) {
162 MS_LOG(ERROR) << "Invalid context.";
163 return "";
164 }
165 return GetValue<std::string>(data_, kModelOptionProvider);
166 }
167
SetProvider(const std::string & provider)168 void DeviceInfoContext::SetProvider(const std::string &provider) {
169 if (data_ == nullptr) {
170 MS_LOG(ERROR) << "Invalid context.";
171 return;
172 }
173 data_->params[kModelOptionProvider] = provider;
174 }
175
GetProviderDevice() const176 std::string DeviceInfoContext::GetProviderDevice() const {
177 if (data_ == nullptr) {
178 MS_LOG(ERROR) << "Invalid context.";
179 return "";
180 }
181 return GetValue<std::string>(data_, kModelOptionProviderDevice);
182 }
183
SetProviderDevice(const std::string & device)184 void DeviceInfoContext::SetProviderDevice(const std::string &device) {
185 if (data_ == nullptr) {
186 MS_LOG(ERROR) << "Invalid context.";
187 return;
188 }
189 data_->params[kModelOptionProviderDevice] = device;
190 }
191
SetAllocator(const std::shared_ptr<Allocator> & allocator)192 void DeviceInfoContext::SetAllocator(const std::shared_ptr<Allocator> &allocator) {
193 if (data_ == nullptr) {
194 MS_LOG(ERROR) << "Invalid context.";
195 return;
196 }
197 data_->allocator = allocator;
198 }
199
GetAllocator() const200 std::shared_ptr<Allocator> DeviceInfoContext::GetAllocator() const {
201 if (data_ == nullptr) {
202 MS_LOG(ERROR) << "Invalid context.";
203 return nullptr;
204 }
205 return data_->allocator;
206 }
207
SetEnableFP16(bool is_fp16)208 void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
209 if (data_ == nullptr) {
210 MS_LOG(ERROR) << "Invalid context.";
211 return;
212 }
213 data_->params[kModelOptionCpuEnableFP16] = is_fp16;
214 }
215
GetEnableFP16() const216 bool CPUDeviceInfo::GetEnableFP16() const {
217 if (data_ == nullptr) {
218 MS_LOG(ERROR) << "Invalid context.";
219 return false;
220 }
221 return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
222 }
223
SetEnableFP16(bool is_fp16)224 void GPUDeviceInfo::SetEnableFP16(bool is_fp16) {
225 if (data_ == nullptr) {
226 MS_LOG(ERROR) << "Invalid context.";
227 return;
228 }
229 data_->params[kModelOptionGPUEnableFP16] = is_fp16;
230 }
GetEnableFP16() const231 bool GPUDeviceInfo::GetEnableFP16() const {
232 if (data_ == nullptr) {
233 MS_LOG(ERROR) << "Invalid context.";
234 return false;
235 }
236 return GetValue<bool>(data_, kModelOptionGPUEnableFP16);
237 }
238
SetFrequency(int frequency)239 void KirinNPUDeviceInfo::SetFrequency(int frequency) {
240 if (data_ == nullptr) {
241 MS_LOG(ERROR) << "Invalid context.";
242 return;
243 }
244 data_->params[kModelOptionKirinNpuFrequency] = frequency;
245 }
246
GetFrequency() const247 int KirinNPUDeviceInfo::GetFrequency() const {
248 if (data_ == nullptr) {
249 MS_LOG(ERROR) << "Invalid context.";
250 return 0;
251 }
252 return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
253 }
254
SetDeviceID(uint32_t device_id)255 void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
256 if (data_ == nullptr) {
257 MS_LOG(ERROR) << "Invalid context.";
258 return;
259 }
260 data_->params[kModelOptionGPUDeviceID] = device_id;
261 }
262
GetDeviceID() const263 uint32_t GPUDeviceInfo::GetDeviceID() const {
264 if (data_ == nullptr) {
265 MS_LOG(ERROR) << "Invalid context.";
266 return 0;
267 }
268 return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
269 }
270
SetPrecisionMode(const std::vector<char> & precision_mode)271 void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
272 MS_LOG(ERROR) << "Unsupported Feature.";
273 }
GetPrecisionModeChar() const274 std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
275 MS_LOG(ERROR) << "Unsupported Feature.";
276 std::vector<char> ret;
277 return ret;
278 }
279
SetDeviceID(uint32_t device_id)280 void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) { MS_LOG(ERROR) << "Unsupported Feature."; }
GetDeviceID() const281 uint32_t Ascend910DeviceInfo::GetDeviceID() const {
282 MS_LOG(ERROR) << "Unsupported Feature.";
283 return 0;
284 }
285
SetDeviceID(uint32_t device_id)286 void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
287 if (data_ == nullptr) {
288 MS_LOG(ERROR) << "Invalid context.";
289 return;
290 }
291 data_->params[kModelOptionAscend310DeviceID] = device_id;
292 }
293
GetDeviceID() const294 uint32_t Ascend310DeviceInfo::GetDeviceID() const {
295 if (data_ == nullptr) {
296 MS_LOG(ERROR) << "Invalid context.";
297 return 0;
298 }
299 return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
300 }
301
SetInsertOpConfigPath(const std::vector<char> & cfg_path)302 void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
303 if (data_ == nullptr) {
304 MS_LOG(ERROR) << "Invalid context.";
305 return;
306 }
307 data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
308 }
GetInsertOpConfigPathChar() const309 std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
310 if (data_ == nullptr) {
311 MS_LOG(ERROR) << "Invalid context.";
312 return std::vector<char>();
313 }
314 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
315 return StringToChar(ref);
316 }
317
SetInputFormat(const std::vector<char> & format)318 void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
319 if (data_ == nullptr) {
320 MS_LOG(ERROR) << "Invalid context.";
321 return;
322 }
323 data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
324 }
325
GetInputFormatChar() const326 std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
327 if (data_ == nullptr) {
328 MS_LOG(ERROR) << "Invalid context.";
329 return std::vector<char>();
330 }
331 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
332 return StringToChar(ref);
333 }
334
SetInputShape(const std::vector<char> & shape)335 void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
336 if (data_ == nullptr) {
337 MS_LOG(ERROR) << "Invalid context.";
338 return;
339 }
340 data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
341 }
GetInputShapeChar() const342 std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
343 if (data_ == nullptr) {
344 MS_LOG(ERROR) << "Invalid context.";
345 return std::vector<char>();
346 }
347 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
348 return StringToChar(ref);
349 }
350
SetDynamicBatchSize(const std::vector<size_t> & dynamic_batch_size)351 void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
352 if (data_ == nullptr) {
353 MS_LOG(ERROR) << "Invalid context.";
354 return;
355 }
356 std::string batchs;
357 for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
358 if (i != 0) {
359 batchs.push_back(',');
360 }
361 batchs += std::to_string(dynamic_batch_size[i]);
362 }
363 data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
364 }
365
GetDynamicBatchSizeChar() const366 std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
367 if (data_ == nullptr) {
368 MS_LOG(ERROR) << "Invalid context.";
369 return std::vector<char>();
370 }
371 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
372 return StringToChar(ref);
373 }
374
SetPrecisionMode(const std::vector<char> & precision_mode)375 void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
376 if (data_ == nullptr) {
377 MS_LOG(ERROR) << "Invalid context.";
378 return;
379 }
380 data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
381 }
382
GetPrecisionModeChar() const383 std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
384 if (data_ == nullptr) {
385 MS_LOG(ERROR) << "Invalid context.";
386 return std::vector<char>();
387 }
388 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
389 return StringToChar(ref);
390 }
391
SetOpSelectImplMode(const std::vector<char> & op_select_impl_mode)392 void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
393 if (data_ == nullptr) {
394 MS_LOG(ERROR) << "Invalid context.";
395 return;
396 }
397 data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
398 }
399
GetOpSelectImplModeChar() const400 std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
401 if (data_ == nullptr) {
402 MS_LOG(ERROR) << "Invalid context.";
403 return std::vector<char>();
404 }
405 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
406 return StringToChar(ref);
407 }
408
SetFusionSwitchConfigPath(const std::vector<char> & cfg_path)409 void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
410 if (data_ == nullptr) {
411 MS_LOG(ERROR) << "Invalid context.";
412 return;
413 }
414 data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
415 }
GetFusionSwitchConfigPathChar() const416 std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
417 if (data_ == nullptr) {
418 MS_LOG(ERROR) << "Invalid context.";
419 return std::vector<char>();
420 }
421 const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
422 return StringToChar(ref);
423 }
424
SetInputShapeMap(const std::map<int,std::vector<int>> & shape)425 void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
426 if (data_ == nullptr) {
427 MS_LOG(ERROR) << "Invalid context.";
428 return;
429 }
430 data_->params[kModelOptionAscend310InputShapeMap] = shape;
431 }
432
GetInputShapeMap() const433 std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
434 if (data_ == nullptr) {
435 MS_LOG(ERROR) << "Invalid context.";
436 return std::map<int, std::vector<int>>();
437 }
438 return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
439 }
440
SetOutputType(enum DataType output_type)441 void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
442 if (data_ == nullptr) {
443 MS_LOG(ERROR) << "Invalid context.";
444 return;
445 }
446 data_->params[kModelOptionAscend310OutputType] = output_type;
447 }
448
GetOutputType() const449 enum DataType Ascend310DeviceInfo::GetOutputType() const {
450 if (data_ == nullptr) {
451 MS_LOG(ERROR) << "Invalid context.";
452 return DataType::kTypeUnknown;
453 }
454 return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
455 }
456
SetBufferOptimizeMode(const std::vector<char> & buffer_optimize_mode)457 void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
458 if (data_ == nullptr) {
459 MS_LOG(ERROR) << "Invalid context.";
460 return;
461 }
462 data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
463 }
464
GetBufferOptimizeModeChar() const465 std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
466 if (data_ == nullptr) {
467 MS_LOG(ERROR) << "Invalid context.";
468 return std::vector<char>();
469 }
470 const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
471 return StringToChar(ref);
472 }
473 } // namespace mindspore
474