/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H

#include <string>
#include <memory>
#include <vector>
#include <map>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
enum DelegateMode {
  kNoDelegate = 0,
  kCoreML = 1,
  kNNAPI = 2,
};

enum DeviceType {
  kCPU = 0,
  kGPU,
  kKirinNPU,
  kAscend,
  kAscend910,
  kAscend310,
  kCustomDevice,
  kAllDevice,
  // ohos-only device range [60, 80)
  kNNRt = 60,
  // add new type here
  kInvalidDeviceType = 100,
};

class Allocator;
class AbstractDelegate;
class Delegate;
class DeviceInfoContext;

/// \brief Context is used to store environment variables during execution.
class MS_API Context {
 public:
  struct Data;
  Context();
  ~Context() = default;
  Context(const Context &rhs) : data_(rhs.data_) {}

  /// \brief Set the number of threads at runtime.
  ///
  /// \param[in] thread_num The number of threads at runtime.
  void SetThreadNum(int32_t thread_num);

  /// \brief Get the current thread number setting.
  ///
  /// \return The current thread number setting.
  int32_t GetThreadNum() const;

  /// \brief Set the communication group info file path.
  ///
  /// \param[in] group_info_file The communication group info file for distributed inference.
  void SetGroupInfoFile(std::string group_info_file);

  /// \brief Get the communication group info file path.
  ///
  /// \return The communication group info file path setting.
  std::string GetGroupInfoFile() const;

  /// \brief Set the number of operators that run in parallel at runtime.
  ///
  /// \param[in] parallel_num The number of operators that run in parallel at runtime.
  void SetInterOpParallelNum(int32_t parallel_num);

  /// \brief Get the current operator parallel number setting.
  ///
  /// \return The current operator parallel number setting.
  int32_t GetInterOpParallelNum() const;

  /// \brief Set the thread affinity to CPU cores.
  ///
  /// \param[in] mode 0: no affinities, 1: big cores first, 2: little cores first.
  void SetThreadAffinity(int mode);

  /// \brief Get the thread affinity mode of CPU cores.
  ///
  /// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first.
  int GetThreadAffinityMode() const;

  /// \brief Set the list of CPU cores that threads are bound to.
  ///
  /// \note If core_list and mode are both set by SetThreadAffinity, core_list takes effect and mode does not.
  ///
  /// \param[in] core_list A vector of CPU core ids to bind threads to.
  void SetThreadAffinity(const std::vector<int> &core_list);
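  /* A minimal usage sketch of the thread settings above (assumes a fully constructed
   * Context named `context`; the values are illustrative only):
   *
   *   context.SetThreadNum(4);                 // run with 4 threads
   *   context.SetThreadAffinity({0, 1, 2, 3}); // bind them to cores 0-3; this core list
   *                                            // overrides any affinity mode set earlier
   */
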
  /// \brief Get the list of CPU cores that threads are bound to.
  ///
  /// \return core_list A vector of CPU core ids that threads are bound to.
  std::vector<int32_t> GetThreadAffinityCoreList() const;

  /// \brief Set whether to perform model inference or training in parallel.
  ///
  /// \param[in] is_parallel true: in parallel; false: not in parallel.
  void SetEnableParallel(bool is_parallel);

  /// \brief Get whether model inference or training is performed in parallel.
  ///
  /// \return Bool value that indicates whether to run in parallel.
  bool GetEnableParallel() const;

  /// \brief Set the built-in delegate mode to access a third-party AI framework.
  ///
  /// \param[in] mode The built-in delegate mode.
  void SetBuiltInDelegate(DelegateMode mode);

  /// \brief Get the built-in delegate mode of the third-party AI framework.
  ///
  /// \return The built-in delegate mode.
  DelegateMode GetBuiltInDelegate() const;

  /// \brief Set the delegate to access a third-party AI framework.
  ///
  /// \param[in] delegate The custom delegate.
  void set_delegate(const std::shared_ptr<AbstractDelegate> &delegate);

  // deprecated
  void SetDelegate(const std::shared_ptr<Delegate> &delegate);

  /// \brief Get the delegate of the third-party AI framework.
  ///
  /// \return Pointer to the custom delegate.
  std::shared_ptr<AbstractDelegate> get_delegate() const;

  // deprecated
  std::shared_ptr<Delegate> GetDelegate() const;

  /// \brief Set whether a quantized model runs as a float model on multiple devices.
  ///
  /// \param[in] float_mode true: run as a float model; false: do not run as a float model.
  void SetMultiModalHW(bool float_mode);

  /// \brief Get the mode in which the model runs.
  ///
  /// \return Bool value that indicates whether the model runs as a float model.
  bool GetMultiModalHW() const;

  /// \brief Get a mutable reference of the DeviceInfoContext vector in this context. Only MindSpore Lite supports
  /// heterogeneous scenarios with multiple members in the vector.
  ///
  /// \return Mutable reference of the DeviceInfoContext vector in this context.
  std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();

 private:
  std::shared_ptr<Data> data_;
};

/// \brief DeviceInfoContext defines different device contexts.
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
 public:
  struct Data;

  DeviceInfoContext();
  virtual ~DeviceInfoContext() = default;

  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  virtual enum DeviceType GetDeviceType() const = 0;

  /// \brief Provides an RTTI-like conversion even when the -fno-rtti compilation option is turned on: converts this
  /// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
  ///
  /// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
  template <class T>
  std::shared_ptr<T> Cast() {
    static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
    if (GetDeviceType() != T().GetDeviceType()) {
      return nullptr;
    }

    return std::static_pointer_cast<T>(shared_from_this());
  }
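  /* A minimal Cast sketch (assumes `device_info` is a std::shared_ptr<DeviceInfoContext> that
   * actually points to a CPUDeviceInfo; the variable names are illustrative only):
   *
   *   if (auto cpu_info = device_info->Cast<CPUDeviceInfo>()) {
   *     cpu_info->SetEnableFP16(true);
   *   }
   */
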
  /// \brief Obtain the provider's name.
  ///
  /// \return The provider's name.
  inline std::string GetProvider() const;

  /// \brief Set the provider's name.
  ///
  /// \param[in] provider Define the provider's name.
  inline void SetProvider(const std::string &provider);

  /// \brief Obtain the provider's device type.
  ///
  /// \return The provider's device type.
  inline std::string GetProviderDevice() const;

  /// \brief Set the provider's device type.
  ///
  /// \param[in] device Define the provider's device type, e.g. CPU.
  inline void SetProviderDevice(const std::string &device);

  /// \brief Set the memory allocator.
  ///
  /// \param[in] allocator Define the memory allocator, which can be defined by the user.
  void SetAllocator(const std::shared_ptr<Allocator> &allocator);

  /// \brief Obtain the memory allocator.
  ///
  /// \return The memory allocator.
  std::shared_ptr<Allocator> GetAllocator() const;

 protected:
  std::vector<char> GetProviderChar() const;
  void SetProvider(const std::vector<char> &provider);
  std::vector<char> GetProviderDeviceChar() const;
  void SetProviderDevice(const std::vector<char> &device);

  std::shared_ptr<Data> data_;
};

std::string DeviceInfoContext::GetProvider() const { return CharToString(GetProviderChar()); }
void DeviceInfoContext::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }
std::string DeviceInfoContext::GetProviderDevice() const { return CharToString(GetProviderDeviceChar()); }
void DeviceInfoContext::SetProviderDevice(const std::string &device) { SetProviderDevice(StringToChar(device)); }

/// \brief Derived from DeviceInfoContext, the configuration for running the model automatically on the host devices,
/// including CPU/GPU/NPU/Ascend310/Ascend910. This option is only valid for MindSpore Lite.
class MS_API AutoDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAllDevice; };
};

/// \brief Derived from DeviceInfoContext, the configuration of the model running on the CPU. This option is only valid
/// for MindSpore Lite.
class MS_API CPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };

  /// \brief Set whether to perform float16 inference.
  ///
  /// \param[in] is_fp16 Whether to enable float16 inference.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;
};
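/* A minimal end-to-end sketch of building a CPU inference context (the variable names are
 * illustrative; the configured context is then passed to the model build/compile step):
 *
 *   auto context = std::make_shared<mindspore::Context>();
 *   context->SetThreadNum(2);
 *   auto cpu_info = std::make_shared<mindspore::CPUDeviceInfo>();
 *   cpu_info->SetEnableFP16(false);
 *   context->MutableDeviceInfo().push_back(cpu_info);
 */
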
/// \brief Derived from DeviceInfoContext, the configuration of the model running on the NPU. This option is only valid
/// for MindSpore Lite.
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };

  /// \brief Set whether to perform float16 inference.
  ///
  /// \param[in] is_fp16 Whether to enable float16 inference.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;

  /// \brief Set the NPU frequency.
  ///
  /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
  /// performance), default as 3.
  void SetFrequency(int frequency);

  /// \brief Get the NPU frequency.
  ///
  /// \return NPU frequency.
  int GetFrequency() const;
};

/// \brief Derived from DeviceInfoContext, the configuration of the model running on the GPU.
class MS_API GPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };

  /// \brief Set the device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;

  /// \brief Get the distribution rank id.
  ///
  /// \return The rank id.
  int GetRankID() const;

  /// \brief Get the distribution group size.
  ///
  /// \return The group size.
  int GetGroupSize() const;

  /// \brief Set the precision mode.
  ///
  /// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default.
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// \brief Get the precision mode.
  ///
  /// \return The precision mode.
  inline std::string GetPrecisionMode() const;

  /// \brief Set whether to perform float16 inference.
  ///
  /// \param[in] is_fp16 Whether to enable float16 inference.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;

  /// \brief Set whether to share memory with OpenGL.
  ///
  /// \param[in] is_enable_gl_texture Whether to enable sharing OpenCL memory with OpenGL.
  void SetEnableGLTexture(bool is_enable_gl_texture);

  /// \brief Get whether memory sharing with OpenGL is enabled.
  ///
  /// \return Whether memory sharing with OpenGL is enabled.
  bool GetEnableGLTexture() const;

  /// \brief Set the current OpenGL context.
  ///
  /// \param[in] gl_context Current OpenGL context.
  void SetGLContext(void *gl_context);

  /// \brief Get the current OpenGL context.
  ///
  /// \return The OpenGL context in use.
  void *GetGLContext() const;

  /// \brief Set the current OpenGL display.
  ///
  /// \param[in] gl_display Current OpenGL display.
  void SetGLDisplay(void *gl_display);

  /// \brief Get the current OpenGL display.
  ///
  /// \return The OpenGL display in use.
  void *GetGLDisplay() const;

 private:
  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;
};

void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
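/* A minimal GPU configuration sketch (illustrative values; `context` is the Context from the
 * CPU sketch above):
 *
 *   auto gpu_info = std::make_shared<mindspore::GPUDeviceInfo>();
 *   gpu_info->SetDeviceID(0);
 *   gpu_info->SetPrecisionMode("fp16");  // or keep the default "origin"
 *   context->MutableDeviceInfo().push_back(gpu_info);
 */
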
/// \brief Derived from DeviceInfoContext, the configuration of the model running on the Ascend. This option is
/// invalid for MindSpore Lite.
class MS_API AscendDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };

  /// \brief Set the device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;

  /// \brief Set the distribution rank id.
  ///
  /// \param[in] rank_id The rank id.
  void SetRankID(uint32_t rank_id);

  /// \brief Get the distribution rank id.
  ///
  /// \return The rank id.
  uint32_t GetRankID() const;

  /// \brief Set the AIPP configuration file path.
  ///
  /// \param[in] cfg_path AIPP configuration file path.
  inline void SetInsertOpConfigPath(const std::string &cfg_path);

  /// \brief Get the AIPP configuration file path.
  ///
  /// \return AIPP configuration file path.
  inline std::string GetInsertOpConfigPath() const;

  /// \brief Set the format of model inputs.
  ///
  /// \param[in] format Optional "NCHW", "NHWC", and "ND".
  inline void SetInputFormat(const std::string &format);

  /// \brief Get the format of model inputs.
  ///
  /// \return The format of model inputs.
  inline std::string GetInputFormat() const;

  /// \brief Set the shape of model inputs.
  ///
  /// \param[in] shape e.g. "input_op_name1:1,2,3,4;input_op_name2:4,3,2,1".
  inline void SetInputShape(const std::string &shape);

  /// \brief Get the shape of model inputs.
  ///
  /// \return The shape of model inputs.
  inline std::string GetInputShape() const;

  /// \brief Set the shape of model inputs.
  ///
  /// \param[in] shape e.g. {{0, {1,2,3,4}}, {1, {4,3,2,1}}} means the first input shape is 1,2,3,4 and the second
  /// input shape is 4,3,2,1.
  void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);

  /// \brief Get the shape of model inputs.
  ///
  /// \return The shape of model inputs.
  std::map<int, std::vector<int>> GetInputShapeMap() const;

  /// \brief Set the dynamic batch sizes of model inputs. Ranges from 2 to 100.
  ///
  /// \param[in] dynamic_batch_size e.g. {1, 2} means batch sizes 1 and 2 are configured.
  void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);

  /// \brief Get the dynamic batch sizes of model inputs.
  ///
  /// \return The dynamic batch sizes of model inputs in string format.
  inline std::string GetDynamicBatchSize() const;

  /// \brief Set the dynamic image size of model inputs.
  ///
  /// \param[in] dynamic_image_size size hw e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
  inline void SetDynamicImageSize(const std::string &dynamic_image_size);

  /// \brief Get the dynamic image size of model inputs.
  ///
  /// \return The image size of model inputs.
  inline std::string GetDynamicImageSize() const;
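  /* A sketch of the shape-related setters above, using illustrative values that mirror the
   * formats documented in the parameters (`ascend_info` is a configured AscendDeviceInfo):
   *
   *   ascend_info->SetInputShape("input_op_name1:1,2,3,4;input_op_name2:4,3,2,1");
   *   ascend_info->SetInputShapeMap({{0, {1, 2, 3, 4}}, {1, {4, 3, 2, 1}}});
   *   ascend_info->SetDynamicBatchSize({1, 2});
   *   ascend_info->SetDynamicImageSize("66,88;32,64");
   */
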
  /// \brief Set the type of model outputs.
  ///
  /// \param[in] output_type FP32, UINT8 or FP16.
  void SetOutputType(enum DataType output_type);

  /// \brief Get the type of model outputs.
  ///
  /// \return The set type of model outputs.
  enum DataType GetOutputType() const;

  /// \brief Set the precision mode of the model.
  ///
  /// \param[in] precision_mode Optional "enforce_fp16", "preferred_fp32", "enforce_origin", "enforce_fp32" and
  /// "preferred_optimal". "enforce_fp16" is set as default.
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// \brief Get the precision mode of the model.
  ///
  /// \return The precision mode.
  inline std::string GetPrecisionMode() const;

  /// \brief Set the op select implementation mode.
  ///
  /// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
  /// default.
  inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);

  /// \brief Get the op select implementation mode.
  ///
  /// \return The set op select implementation mode.
  inline std::string GetOpSelectImplMode() const;

  /// \brief Set the fusion switch config file path. Controls which fusion passes are turned off.
  ///
  /// \param[in] cfg_path The fusion switch config file path.
  inline void SetFusionSwitchConfigPath(const std::string &cfg_path);

  /// \brief Get the fusion switch config file path.
  ///
  /// \return The fusion switch config file path.
  inline std::string GetFusionSwitchConfigPath() const;

  /// \brief Set the buffer optimize mode.
  ///
  /// \param[in] buffer_optimize_mode Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize",
  /// default as "l2_optimize".
  inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);

  /// \brief Get the buffer optimize mode.
  ///
  /// \return The buffer optimize mode.
  inline std::string GetBufferOptimizeMode() const;

 private:
  void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetInsertOpConfigPathChar() const;

  void SetInputFormat(const std::vector<char> &format);
  std::vector<char> GetInputFormatChar() const;

  void SetInputShape(const std::vector<char> &shape);
  std::vector<char> GetInputShapeChar() const;

  std::vector<char> GetDynamicBatchSizeChar() const;

  void SetDynamicImageSize(const std::vector<char> &dynamic_image_size);
  std::vector<char> GetDynamicImageSizeChar() const;

  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;

  void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  std::vector<char> GetOpSelectImplModeChar() const;

  void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetFusionSwitchConfigPathChar() const;

  void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  std::vector<char> GetBufferOptimizeModeChar() const;
};

using Ascend310DeviceInfo = AscendDeviceInfo;
using Ascend910DeviceInfo = AscendDeviceInfo;

void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }

void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }

void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }

std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
  SetDynamicImageSize(StringToChar(dynamic_image_size));
}

std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }

void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
  return CharToString(GetFusionSwitchConfigPathChar());
}

void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }

struct Extension {
  std::string name;
  std::vector<uint8_t> value;
};

class MS_API NNRTDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kNNRt; };

  /// \brief Set the device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(size_t device_id);

  /// \brief Get the device id.
  ///
  /// \return The device id.
  size_t GetDeviceID() const;

  /// \brief Set the performance mode.
  ///
  /// \param[in] performance_mode The performance mode.
  void SetPerformanceMode(int performance_mode);

  /// \brief Get the performance mode.
  ///
  /// \return The performance mode.
  int GetPerformanceMode() const;

  /// \brief Set the priority.
  ///
  /// \param[in] priority The priority.
  void SetPriority(int priority);

  /// \brief Get the priority.
  ///
  /// \return The priority.
  int GetPriority() const;

  /// \brief Set whether to perform float16 inference.
  ///
  /// \param[in] is_fp16 Whether to enable float16 inference.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;

  /// \brief Set the extensions.
  ///
  /// \param[in] extensions The extension array.
  void SetExtensions(const std::vector<Extension> &extensions);

  /// \brief Get the extensions.
  ///
  /// \return The extension array.
  std::vector<Extension> GetExtensions() const;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CONTEXT_H