/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H

#include <string>
#include <memory>
#include <vector>
#include <map>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
enum DeviceType {
  kCPU = 0,
  kGPU,
  kKirinNPU,
  kAscend910,
  kAscend310,
  // add new type here
  // OHOS-only device range: [60, 80)
  kNNRt = 60,
  kInvalidDeviceType = 100,
};

class Allocator;
class Delegate;
class DeviceInfoContext;

/// \brief Context is used to store environment variables during execution.
class MS_API Context {
 public:
  struct Data;
  Context();
  ~Context() = default;

  /// \brief Set the number of threads at runtime. Only valid for Lite.
  ///
  /// \param[in] thread_num The number of threads at runtime.
  void SetThreadNum(int32_t thread_num);

  /// \brief Get the current thread number setting. Only valid for Lite.
  ///
  /// \return The current thread number setting.
  int32_t GetThreadNum() const;

  /// \brief Set the thread affinity to CPU cores. Only valid for Lite.
  ///
  /// \param[in] mode The affinity mode. 0: no affinity, 1: big cores first, 2: little cores first.
  void SetThreadAffinity(int mode);

  /// \brief Get the thread affinity mode. Only valid for Lite.
  ///
  /// \return The affinity mode. 0: no affinity, 1: big cores first, 2: little cores first.
  int GetThreadAffinityMode() const;

  /// \brief Bind threads to a list of CPU cores. Only valid for Lite.
  ///
  /// \note If both core_list and mode are set via SetThreadAffinity, core_list takes effect and mode is ignored.
  ///
  /// \param[in] core_list A list of CPU core IDs to bind threads to.
  void SetThreadAffinity(const std::vector<int> &core_list);

  /// \brief Get the list of bound CPU cores. Only valid for Lite.
  ///
  /// \return A vector of bound CPU core IDs.
  std::vector<int32_t> GetThreadAffinityCoreList() const;

  /// \brief Set whether to perform model inference or training in parallel. Only valid for Lite.
  ///
  /// \param[in] is_parallel true: run in parallel; false: run serially.
  void SetEnableParallel(bool is_parallel);

  /// \brief Get whether model inference or training is performed in parallel. Only valid for Lite.
  ///
  /// \return Whether inference or training runs in parallel.
  bool GetEnableParallel() const;

  /// \brief Set a Delegate to access a third-party AI framework. Only valid for Lite.
  ///
  /// \param[in] delegate Pointer to the custom delegate.
  void SetDelegate(const std::shared_ptr<Delegate> &delegate);

  /// \brief Get the delegate of the third-party AI framework. Only valid for Lite.
  ///
  /// \return Pointer to the custom delegate.
  std::shared_ptr<Delegate> GetDelegate() const;

  /// \brief Get a mutable reference to the DeviceInfoContext vector in this context. Only MindSpore Lite supports
  /// heterogeneous scenarios with multiple members in the vector.
  ///
  /// \return Mutable reference to the DeviceInfoContext vector in this context.
  std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();

 private:
  std::shared_ptr<Data> data_;
};
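// Usage sketch (illustrative only, not part of the API surface): configure a
// Lite Context with four threads bound to big cores first, then register a
// CPU backend. CPUDeviceInfo is declared later in this header.
//
//   auto context = std::make_shared<mindspore::Context>();
//   context->SetThreadNum(4);
//   context->SetThreadAffinity(1);  // 1: big cores first
//   auto cpu_info = std::make_shared<mindspore::CPUDeviceInfo>();
//   context->MutableDeviceInfo().push_back(cpu_info);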
/// \brief DeviceInfoContext defines different device contexts.
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
 public:
  struct Data;

  DeviceInfoContext();
  virtual ~DeviceInfoContext() = default;

  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  virtual enum DeviceType GetDeviceType() const = 0;

  /// \brief An RTTI-like cast that also works when the -fno-rtti compilation option is turned on. Converts this
  /// DeviceInfoContext to a shared pointer of type T; returns nullptr if the conversion fails.
  ///
  /// \tparam T The target type, which must derive from DeviceInfoContext.
  /// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
  template <class T>
  std::shared_ptr<T> Cast() {
    static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
    if (GetDeviceType() != T().GetDeviceType()) {
      return nullptr;
    }

    return std::static_pointer_cast<T>(shared_from_this());
  }

  /// \brief Obtain the provider's name.
  ///
  /// \return The provider's name.
  std::string GetProvider() const;

  /// \brief Set the provider's name.
  ///
  /// \param[in] provider The provider's name.
  void SetProvider(const std::string &provider);

  /// \brief Obtain the provider's device type.
  ///
  /// \return The provider's device type.
  std::string GetProviderDevice() const;

  /// \brief Set the provider's device type.
  ///
  /// \param[in] device The provider's device type, e.g. CPU.
  void SetProviderDevice(const std::string &device);

  /// \brief Set the memory allocator.
  ///
  /// \param[in] allocator The memory allocator, which can be user-defined.
  void SetAllocator(const std::shared_ptr<Allocator> &allocator);

  /// \brief Obtain the memory allocator.
  ///
  /// \return The memory allocator.
  std::shared_ptr<Allocator> GetAllocator() const;

 protected:
  std::shared_ptr<Data> data_;
};
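// Usage sketch (illustrative only): Cast<T>() succeeds only when the dynamic
// device type matches T, so it can stand in for dynamic_pointer_cast under
// -fno-rtti. CPUDeviceInfo and GPUDeviceInfo are declared later in this header.
//
//   std::shared_ptr<mindspore::DeviceInfoContext> info = std::make_shared<mindspore::CPUDeviceInfo>();
//   auto cpu_info = info->Cast<mindspore::CPUDeviceInfo>();  // non-null
//   auto gpu_info = info->Cast<mindspore::GPUDeviceInfo>();  // nullptr: device type mismatch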
/// \brief Derived from DeviceInfoContext, this is the configuration for a model running on the CPU. This option is
/// only valid for MindSpore Lite.
class MS_API CPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; }

  /// \brief Set whether to enable float16 inference.
  ///
  /// \param[in] is_fp16 Whether to enable float16 inference.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;
};

/// \brief Derived from DeviceInfoContext, this is the configuration for a model running on the NPU. This option is
/// only valid for MindSpore Lite.
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; }

  /// \brief Set the NPU frequency.
  ///
  /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), or 4 (extreme
  /// performance); defaults to 3.
  void SetFrequency(int frequency);

  /// \brief Get the NPU frequency.
  ///
  /// \return The NPU frequency.
  int GetFrequency() const;
};
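// Usage sketch (illustrative only): a heterogeneous Lite setup that prefers the
// Kirin NPU and falls back to the CPU. The order of entries in
// MutableDeviceInfo() is assumed here to express scheduling preference.
//
//   // given: std::shared_ptr<mindspore::Context> context
//   auto npu_info = std::make_shared<mindspore::KirinNPUDeviceInfo>();
//   npu_info->SetFrequency(3);  // 3: high performance
//   auto cpu_info = std::make_shared<mindspore::CPUDeviceInfo>();
//   cpu_info->SetEnableFP16(true);
//   context->MutableDeviceInfo().push_back(npu_info);  // preferred device
//   context->MutableDeviceInfo().push_back(cpu_info);  // fallback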
/// \brief Derived from DeviceInfoContext, this is the configuration for a model running on the GPU.
class MS_API GPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; }

  /// \brief Set the device ID.
  ///
  /// \param[in] device_id The device ID.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device ID.
  ///
  /// \return The device ID.
  uint32_t GetDeviceID() const;

  /// \brief Set the precision mode.
  ///
  /// \param[in] precision_mode Optional "origin" or "fp16"; defaults to "origin".
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// \brief Get the precision mode.
  ///
  /// \return The precision mode.
  inline std::string GetPrecisionMode() const;

  /// \brief Set whether to enable float16 inference.
  ///
  /// \param[in] is_fp16 Whether to enable float16 inference.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;

 private:
  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;
};

void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
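// Usage sketch (illustrative only): select GPU 0 and request float16 execution
// using the option values documented above.
//
//   // given: std::shared_ptr<mindspore::Context> context
//   auto gpu_info = std::make_shared<mindspore::GPUDeviceInfo>();
//   gpu_info->SetDeviceID(0);
//   gpu_info->SetPrecisionMode("fp16");  // or leave as the default "origin"
//   context->MutableDeviceInfo().push_back(gpu_info);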
/// \brief Derived from DeviceInfoContext, this is the configuration for a model running on the Ascend910. This
/// option is invalid for MindSpore Lite.
class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; }

  /// \brief Set the device ID.
  ///
  /// \param[in] device_id The device ID.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device ID.
  ///
  /// \return The device ID.
  uint32_t GetDeviceID() const;
};

/// \brief Derived from DeviceInfoContext, this is the configuration for a model running on the Ascend310. This
/// option is invalid for MindSpore Lite.
class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; }

  /// \brief Set the device ID.
  ///
  /// \param[in] device_id The device ID.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device ID.
  ///
  /// \return The device ID.
  uint32_t GetDeviceID() const;

  /// \brief Set the AIPP configuration file path.
  ///
  /// \param[in] cfg_path AIPP configuration file path.
  inline void SetInsertOpConfigPath(const std::string &cfg_path);

  /// \brief Get the AIPP configuration file path.
  ///
  /// \return AIPP configuration file path.
  inline std::string GetInsertOpConfigPath() const;

  /// \brief Set the format of model inputs.
  ///
  /// \param[in] format Optional "NCHW", "NHWC", etc.
  inline void SetInputFormat(const std::string &format);

  /// \brief Get the format of model inputs.
  ///
  /// \return The format of model inputs.
  inline std::string GetInputFormat() const;

  /// \brief Set the shape of model inputs.
  ///
  /// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
  inline void SetInputShape(const std::string &shape);

  /// \brief Get the shape of model inputs.
  ///
  /// \return The shape of model inputs.
  inline std::string GetInputShape() const;

  /// \brief Set the shape of model inputs.
  ///
  /// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input has shape 1,2,3,4 and the second
  /// input has shape 4,3,2,1.
  void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);

  /// \brief Get the shape of model inputs.
  ///
  /// \return The shape of model inputs.
  std::map<int, std::vector<int>> GetInputShapeMap() const;

  /// \brief Set the supported dynamic batch sizes.
  ///
  /// \param[in] dynamic_batch_size The supported batch sizes.
  void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);

  /// \brief Get the supported dynamic batch sizes.
  ///
  /// \return The supported batch sizes as a string.
  inline std::string GetDynamicBatchSize() const;

  /// \brief Set the type of model outputs.
  ///
  /// \param[in] output_type FP32, UINT8, or FP16; defaults to FP32.
  void SetOutputType(enum DataType output_type);

  /// \brief Get the type of model outputs.
  ///
  /// \return The set type of model outputs.
  enum DataType GetOutputType() const;

  /// \brief Set the precision mode of the model.
  ///
  /// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype", or
  /// "allow_mix_precision"; defaults to "force_fp16".
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// \brief Get the precision mode of the model.
  ///
  /// \return The set precision mode.
  inline std::string GetPrecisionMode() const;

  /// \brief Set the op select implementation mode.
  ///
  /// \param[in] op_select_impl_mode Optional "high_performance" or "high_precision"; defaults to "high_performance".
  inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);

  /// \brief Get the op select implementation mode.
  ///
  /// \return The set op select implementation mode.
  inline std::string GetOpSelectImplMode() const;

  /// \brief Set the fusion switch configuration file path.
  ///
  /// \param[in] cfg_path Fusion switch configuration file path.
  inline void SetFusionSwitchConfigPath(const std::string &cfg_path);

  /// \brief Get the fusion switch configuration file path.
  ///
  /// \return Fusion switch configuration file path.
  inline std::string GetFusionSwitchConfigPath() const;

  /// \brief Set the buffer optimize mode.
  ///
  /// \param[in] buffer_optimize_mode Optional "l1_optimize", "l2_optimize", "off_optimize", or "l1_and_l2_optimize";
  /// defaults to "l2_optimize".
  inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);

  /// \brief Get the buffer optimize mode.
  ///
  /// \return The buffer optimize mode.
  inline std::string GetBufferOptimizeMode() const;

 private:
  void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetInsertOpConfigPathChar() const;

  void SetInputFormat(const std::vector<char> &format);
  std::vector<char> GetInputFormatChar() const;

  void SetInputShape(const std::vector<char> &shape);
  std::vector<char> GetInputShapeChar() const;

  std::vector<char> GetDynamicBatchSizeChar() const;

  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;

  void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  std::vector<char> GetOpSelectImplModeChar() const;

  void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetFusionSwitchConfigPathChar() const;

  void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  std::vector<char> GetBufferOptimizeModeChar() const;
};

void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }

void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }

void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }

std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }

void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
  return CharToString(GetFusionSwitchConfigPathChar());
}

void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
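// Usage sketch (illustrative only; option values follow the formats documented
// above, and the input name is a placeholder):
//
//   // given: std::shared_ptr<mindspore::Context> context
//   auto ascend_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
//   ascend_info->SetDeviceID(0);
//   ascend_info->SetInputFormat("NCHW");
//   ascend_info->SetInputShape("input_op_name1: 1,2,3,4");
//   ascend_info->SetDynamicBatchSize({1, 2, 4, 8});
//   ascend_info->SetPrecisionMode("allow_fp32_to_fp16");
//   context->MutableDeviceInfo().push_back(ascend_info);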
/// \brief Derived from DeviceInfoContext, this is the configuration for a model running on NNRt (an OHOS-only
/// device type).
class MS_API NNRTDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kNNRt; }
};

}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CONTEXT_H