• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/api/context.h"
17 #include <any>
18 #include <map>
19 #include <type_traits>
20 #include "cxx_api/factory.h"
21 #include "utils/log_adapter.h"
22 
// Keys used to index DeviceInfoContext::Data::params.
// The per-device *DeviceID names below all alias the single shared device-id key.
constexpr const char *kModelOptionCpuEnableFP16 = "mindspore.option.cpu.enable_fp16";
constexpr const char *kModelOptionGPUEnableFP16 = "mindspore.option.gpu.enable_fp16";
constexpr const char *kModelOptionKirinNpuFrequency = "mindspore.option.kirin_npu.frequency";
constexpr const char *kModelOptionDeviceID = "mindspore.option.device_id";
constexpr const char *kModelOptionGPUDeviceID = kModelOptionDeviceID;
constexpr const char *kModelOptionGPUPrecisionMode = "mindspore.option.gpu.precision_mode";
constexpr const char *kModelOptionAscend910DeviceID = kModelOptionDeviceID;
constexpr const char *kModelOptionAscend310DeviceID = kModelOptionDeviceID;
constexpr const char *kModelOptionAscend310InsertOpCfgPath = "mindspore.option.ascend310.insert_op_config_file_path";
constexpr const char *kModelOptionAscend310InputFormat = "mindspore.option.ascend310.input_format";
constexpr const char *kModelOptionAscend310InputShapeMap = "mindspore.option.ascend310.input_shape_map";
constexpr const char *kModelOptionAscend310InputShape = "mindspore.option.ascend310.input_shape";
constexpr const char *kModelOptionAscend310OutputType = "mindspore.option.ascend310.output_type";
constexpr const char *kModelOptionAscend310PrecisionMode = "mindspore.option.ascend310.precision_mode";
constexpr const char *kModelOptionAscend310OpSelectImplMode = "mindspore.option.ascend310.op_select_impl_mode";
constexpr const char *KModelOptionAscend310FusionSwitchCfgPath =
  "mindspore.option.ascend310.fusion_switch_config_file_path";
constexpr const char *kModelOptionAscend310DynamicBatchSize = "mindspore.option.ascend310.dynamic_batch_size";
constexpr const char *kModelOptionAscend310BufferOptimize = "mindspore.option.ascend310.buffer_optimize";
41 
42 namespace mindspore {
// Empty class: no members are declared for Allocator in this file.
class Allocator {
};
44 
45 struct Context::Data {
46   std::vector<std::shared_ptr<DeviceInfoContext>> device_info_list;
47   int32_t thread_num;
48   bool enable_parallel_ = false;
49   std::vector<int32_t> affinity_core_list_;
50   int affinity_mode_ = 2;
51 };
52 
53 struct DeviceInfoContext::Data {
54   std::map<std::string, std::any> params;
55 };
56 
Context()57 Context::Context() : data_(std::make_shared<Data>()) {}
58 
59 template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
GetValue(const std::shared_ptr<DeviceInfoContext::Data> & data,const std::string & key)60 static const U &GetValue(const std::shared_ptr<DeviceInfoContext::Data> &data, const std::string &key) {
61   static const U empty_result{};
62   if (data == nullptr) {
63     return empty_result;
64   }
65   auto iter = data->params.find(key);
66   if (iter == data->params.end()) {
67     return empty_result;
68   }
69   const std::any &value = iter->second;
70   if (value.type() != typeid(U)) {
71     return empty_result;
72   }
73 
74   return std::any_cast<const U &>(value);
75 }
76 
SetThreadNum(int32_t thread_num)77 void Context::SetThreadNum(int32_t thread_num) {
78   MS_EXCEPTION_IF_NULL(data_);
79   data_->thread_num = thread_num;
80 }
GetThreadNum() const81 int32_t Context::GetThreadNum() const {
82   MS_EXCEPTION_IF_NULL(data_);
83   return data_->thread_num;
84 }
85 
SetEnableParallel(bool is_parallel)86 void Context::SetEnableParallel(bool is_parallel) {
87   MS_EXCEPTION_IF_NULL(data_);
88   data_->enable_parallel_ = is_parallel;
89 }
90 
GetEnableParallel() const91 bool Context::GetEnableParallel() const {
92   MS_EXCEPTION_IF_NULL(data_);
93   return data_->enable_parallel_;
94 }
95 
SetThreadAffinity(int mode)96 void Context::SetThreadAffinity(int mode) {
97   MS_EXCEPTION_IF_NULL(data_);
98   data_->affinity_mode_ = mode;
99 }
GetThreadAffinityMode() const100 int Context::GetThreadAffinityMode() const {
101   MS_EXCEPTION_IF_NULL(data_);
102   return data_->affinity_mode_;
103 }
104 
SetThreadAffinity(const std::vector<int> & core_list)105 void Context::SetThreadAffinity(const std::vector<int> &core_list) {
106   MS_EXCEPTION_IF_NULL(data_);
107   data_->affinity_core_list_ = core_list;
108 }
GetThreadAffinityCoreList() const109 std::vector<int32_t> Context::GetThreadAffinityCoreList() const {
110   MS_EXCEPTION_IF_NULL(data_);
111   return data_->affinity_core_list_;
112 }
113 
MutableDeviceInfo()114 std::vector<std::shared_ptr<DeviceInfoContext>> &Context::MutableDeviceInfo() {
115   MS_EXCEPTION_IF_NULL(data_);
116   return data_->device_info_list;
117 }
118 
DeviceInfoContext()119 DeviceInfoContext::DeviceInfoContext() : data_(std::make_shared<Data>()) {}
120 
SetEnableFP16(bool is_fp16)121 void CPUDeviceInfo::SetEnableFP16(bool is_fp16) {
122   MS_EXCEPTION_IF_NULL(data_);
123   data_->params[kModelOptionCpuEnableFP16] = is_fp16;
124 }
GetEnableFP16() const125 bool CPUDeviceInfo::GetEnableFP16() const {
126   MS_EXCEPTION_IF_NULL(data_);
127   return GetValue<bool>(data_, kModelOptionCpuEnableFP16);
128 }
129 
SetEnableFP16(bool is_fp16)130 void GPUDeviceInfo::SetEnableFP16(bool is_fp16) {
131   MS_EXCEPTION_IF_NULL(data_);
132   data_->params[kModelOptionGPUEnableFP16] = is_fp16;
133 }
GetEnableFP16() const134 bool GPUDeviceInfo::GetEnableFP16() const {
135   MS_EXCEPTION_IF_NULL(data_);
136   return GetValue<bool>(data_, kModelOptionGPUEnableFP16);
137 }
138 
SetFrequency(int frequency)139 void KirinNPUDeviceInfo::SetFrequency(int frequency) {
140   MS_EXCEPTION_IF_NULL(data_);
141   data_->params[kModelOptionKirinNpuFrequency] = frequency;
142 }
GetFrequency() const143 int KirinNPUDeviceInfo::GetFrequency() const {
144   MS_EXCEPTION_IF_NULL(data_);
145   return GetValue<int>(data_, kModelOptionKirinNpuFrequency);
146 }
147 
SetDeviceID(uint32_t device_id)148 void GPUDeviceInfo::SetDeviceID(uint32_t device_id) {
149   MS_EXCEPTION_IF_NULL(data_);
150   data_->params[kModelOptionGPUDeviceID] = device_id;
151 }
GetDeviceID() const152 uint32_t GPUDeviceInfo::GetDeviceID() const {
153   MS_EXCEPTION_IF_NULL(data_);
154   return GetValue<uint32_t>(data_, kModelOptionGPUDeviceID);
155 }
156 
SetPrecisionMode(const std::vector<char> & precision_mode)157 void GPUDeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
158   MS_EXCEPTION_IF_NULL(data_);
159   data_->params[kModelOptionGPUPrecisionMode] = CharToString(precision_mode);
160 }
GetPrecisionModeChar() const161 std::vector<char> GPUDeviceInfo::GetPrecisionModeChar() const {
162   MS_EXCEPTION_IF_NULL(data_);
163   const std::string &ref = GetValue<std::string>(data_, kModelOptionGPUPrecisionMode);
164   return StringToChar(ref);
165 }
166 
SetDeviceID(uint32_t device_id)167 void Ascend910DeviceInfo::SetDeviceID(uint32_t device_id) {
168   MS_EXCEPTION_IF_NULL(data_);
169   data_->params[kModelOptionAscend910DeviceID] = device_id;
170 }
GetDeviceID() const171 uint32_t Ascend910DeviceInfo::GetDeviceID() const {
172   MS_EXCEPTION_IF_NULL(data_);
173   return GetValue<uint32_t>(data_, kModelOptionAscend910DeviceID);
174 }
175 
SetDeviceID(uint32_t device_id)176 void Ascend310DeviceInfo::SetDeviceID(uint32_t device_id) {
177   MS_EXCEPTION_IF_NULL(data_);
178   data_->params[kModelOptionAscend310DeviceID] = device_id;
179 }
GetDeviceID() const180 uint32_t Ascend310DeviceInfo::GetDeviceID() const {
181   MS_EXCEPTION_IF_NULL(data_);
182   return GetValue<uint32_t>(data_, kModelOptionAscend310DeviceID);
183 }
184 
SetInsertOpConfigPath(const std::vector<char> & cfg_path)185 void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::vector<char> &cfg_path) {
186   MS_EXCEPTION_IF_NULL(data_);
187   data_->params[kModelOptionAscend310InsertOpCfgPath] = CharToString(cfg_path);
188 }
GetInsertOpConfigPathChar() const189 std::vector<char> Ascend310DeviceInfo::GetInsertOpConfigPathChar() const {
190   MS_EXCEPTION_IF_NULL(data_);
191   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InsertOpCfgPath);
192   return StringToChar(ref);
193 }
194 
SetInputFormat(const std::vector<char> & format)195 void Ascend310DeviceInfo::SetInputFormat(const std::vector<char> &format) {
196   MS_EXCEPTION_IF_NULL(data_);
197   data_->params[kModelOptionAscend310InputFormat] = CharToString(format);
198 }
GetInputFormatChar() const199 std::vector<char> Ascend310DeviceInfo::GetInputFormatChar() const {
200   MS_EXCEPTION_IF_NULL(data_);
201   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputFormat);
202   return StringToChar(ref);
203 }
204 
SetInputShape(const std::vector<char> & shape)205 void Ascend310DeviceInfo::SetInputShape(const std::vector<char> &shape) {
206   MS_EXCEPTION_IF_NULL(data_);
207   data_->params[kModelOptionAscend310InputShape] = CharToString(shape);
208 }
GetInputShapeChar() const209 std::vector<char> Ascend310DeviceInfo::GetInputShapeChar() const {
210   MS_EXCEPTION_IF_NULL(data_);
211   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310InputShape);
212   return StringToChar(ref);
213 }
214 
SetDynamicBatchSize(const std::vector<size_t> & dynamic_batch_size)215 void Ascend310DeviceInfo::SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size) {
216   MS_EXCEPTION_IF_NULL(data_);
217   std::string batchs = "";
218   for (size_t i = 0; i < dynamic_batch_size.size(); ++i) {
219     if (i != 0) {
220       batchs.push_back(',');
221     }
222     batchs += std::to_string(dynamic_batch_size[i]);
223   }
224   data_->params[kModelOptionAscend310DynamicBatchSize] = batchs;
225 }
GetDynamicBatchSizeChar() const226 std::vector<char> Ascend310DeviceInfo::GetDynamicBatchSizeChar() const {
227   MS_EXCEPTION_IF_NULL(data_);
228   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310DynamicBatchSize);
229   return StringToChar(ref);
230 }
231 
SetPrecisionMode(const std::vector<char> & precision_mode)232 void Ascend310DeviceInfo::SetPrecisionMode(const std::vector<char> &precision_mode) {
233   MS_EXCEPTION_IF_NULL(data_);
234   data_->params[kModelOptionAscend310PrecisionMode] = CharToString(precision_mode);
235 }
GetPrecisionModeChar() const236 std::vector<char> Ascend310DeviceInfo::GetPrecisionModeChar() const {
237   MS_EXCEPTION_IF_NULL(data_);
238   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310PrecisionMode);
239   return StringToChar(ref);
240 }
241 
SetOpSelectImplMode(const std::vector<char> & op_select_impl_mode)242 void Ascend310DeviceInfo::SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode) {
243   MS_EXCEPTION_IF_NULL(data_);
244   data_->params[kModelOptionAscend310OpSelectImplMode] = CharToString(op_select_impl_mode);
245 }
GetOpSelectImplModeChar() const246 std::vector<char> Ascend310DeviceInfo::GetOpSelectImplModeChar() const {
247   MS_EXCEPTION_IF_NULL(data_);
248   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310OpSelectImplMode);
249   return StringToChar(ref);
250 }
251 
SetFusionSwitchConfigPath(const std::vector<char> & cfg_path)252 void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::vector<char> &cfg_path) {
253   MS_EXCEPTION_IF_NULL(data_);
254   data_->params[KModelOptionAscend310FusionSwitchCfgPath] = CharToString(cfg_path);
255 }
GetFusionSwitchConfigPathChar() const256 std::vector<char> Ascend310DeviceInfo::GetFusionSwitchConfigPathChar() const {
257   MS_EXCEPTION_IF_NULL(data_);
258   const std::string &ref = GetValue<std::string>(data_, KModelOptionAscend310FusionSwitchCfgPath);
259   return StringToChar(ref);
260 }
261 
SetInputShapeMap(const std::map<int,std::vector<int>> & shape)262 void Ascend310DeviceInfo::SetInputShapeMap(const std::map<int, std::vector<int>> &shape) {
263   MS_EXCEPTION_IF_NULL(data_);
264   data_->params[kModelOptionAscend310InputShapeMap] = shape;
265 }
GetInputShapeMap() const266 std::map<int, std::vector<int>> Ascend310DeviceInfo::GetInputShapeMap() const {
267   MS_EXCEPTION_IF_NULL(data_);
268   return GetValue<std::map<int, std::vector<int>>>(data_, kModelOptionAscend310InputShapeMap);
269 }
270 
SetOutputType(enum DataType output_type)271 void Ascend310DeviceInfo::SetOutputType(enum DataType output_type) {
272   MS_EXCEPTION_IF_NULL(data_);
273   data_->params[kModelOptionAscend310OutputType] = output_type;
274 }
GetOutputType() const275 enum DataType Ascend310DeviceInfo::GetOutputType() const {
276   MS_EXCEPTION_IF_NULL(data_);
277   return GetValue<enum DataType>(data_, kModelOptionAscend310OutputType);
278 }
279 
SetBufferOptimizeMode(const std::vector<char> & buffer_optimize_mode)280 void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode) {
281   MS_EXCEPTION_IF_NULL(data_);
282   data_->params[kModelOptionAscend310BufferOptimize] = CharToString(buffer_optimize_mode);
283 }
GetBufferOptimizeModeChar() const284 std::vector<char> Ascend310DeviceInfo::GetBufferOptimizeModeChar() const {
285   MS_EXCEPTION_IF_NULL(data_);
286   const std::string &ref = GetValue<std::string>(data_, kModelOptionAscend310BufferOptimize);
287   return StringToChar(ref);
288 }
289 }  // namespace mindspore
290