/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H

#include <string>
#include <memory>
#include <vector>
#include <map>
#include <type_traits>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
enum DelegateMode {
  kNoDelegate = 0,
  kCoreML = 1,
  kNNAPI = 2,
};

enum DeviceType {
  kCPU = 0,
  kGPU,
  kKirinNPU,
  kAscend,
  kAscend910,
  kAscend310,
  kCustomDevice,
  kAllDevice,
  // OHOS-only device range: [60, 80)
  kNNRt = 60,
  // add new type here
  kInvalidDeviceType = 100,
};

class Allocator;
class AbstractDelegate;
class Delegate;
class DeviceInfoContext;

/// \brief Context is used to store environment variables during execution.
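///
/// A minimal usage sketch (illustrative; assumes a MindSpore Lite build with a CPU backend):
/// \code
/// auto context = std::make_shared<mindspore::Context>();
/// context->SetThreadNum(2);
/// context->MutableDeviceInfo().push_back(std::make_shared<mindspore::CPUDeviceInfo>());
/// \endcode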
class MS_API Context {
 public:
  struct Data;
  Context();
  ~Context() = default;
  Context(const Context &rhs) : data_(rhs.data_) {}

  /// \brief Set the number of threads at runtime.
  ///
  /// \param[in] thread_num The number of threads at runtime.
  void SetThreadNum(int32_t thread_num);

  /// \brief Get the current thread number setting.
  ///
  /// \return The current thread number setting.
  int32_t GetThreadNum() const;

  /// \brief Set the communication group info file path.
  ///
  /// \param[in] group_info_file Communication group info file for distributed inference.
  void SetGroupInfoFile(std::string group_info_file);

  /// \brief Get the communication group info file path.
  ///
  /// \return The communication group info file path setting.
  std::string GetGroupInfoFile() const;

  /// \brief Set the parallel number of operators at runtime.
  ///
  /// \param[in] parallel_num The parallel number of operators at runtime.
  void SetInterOpParallelNum(int32_t parallel_num);

  /// \brief Get the current operator parallelism setting.
  ///
  /// \return The current operator parallelism setting.
  int32_t GetInterOpParallelNum() const;

  /// \brief Set the thread affinity mode for binding threads to CPU cores.
  ///
  /// \param[in] mode 0: no affinity, 1: big cores first, 2: little cores first.
  void SetThreadAffinity(int mode);

  /// \brief Get the thread affinity mode.
  ///
  /// \return Thread affinity mode. 0: no affinity, 1: big cores first, 2: little cores first.
  int GetThreadAffinityMode() const;

  /// \brief Set the list of CPU cores to bind threads to.
  ///
  /// \note If both core_list and mode are set via SetThreadAffinity, core_list takes effect and mode is ignored.
  ///
  /// \param[in] core_list A vector of CPU core IDs to bind threads to.
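  ///
  /// A minimal sketch (illustrative; core IDs depend on the target SoC):
  /// \code
  /// mindspore::Context context;
  /// context.SetThreadAffinity(1);             // mode: big cores first
  /// context.SetThreadAffinity({0, 1, 2, 3});  // core_list takes effect; the mode above is ignored
  /// \endcode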
  void SetThreadAffinity(const std::vector<int> &core_list);

  /// \brief Get the list of CPU cores bound by thread affinity.
  ///
  /// \return A vector of bound CPU core IDs.
  std::vector<int32_t> GetThreadAffinityCoreList() const;

  /// \brief Set whether to perform model inference or training in parallel.
  ///
  /// \param[in] is_parallel true: run in parallel; false: do not run in parallel.
  void SetEnableParallel(bool is_parallel);

  /// \brief Get whether model inference or training runs in parallel.
  ///
  /// \return Bool value that indicates whether parallel execution is enabled.
  bool GetEnableParallel() const;

  /// \brief Set the built-in delegate mode to access a third-party AI framework.
  ///
  /// \param[in] mode The built-in delegate mode.
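  ///
  /// A minimal sketch (illustrative; kNNAPI only takes effect on platforms that provide NNAPI):
  /// \code
  /// mindspore::Context context;
  /// context.SetBuiltInDelegate(mindspore::kNNAPI);
  /// \endcode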
  void SetBuiltInDelegate(DelegateMode mode);

  /// \brief Get the built-in delegate mode of the third-party AI framework.
  ///
  /// \return The built-in delegate mode.
  DelegateMode GetBuiltInDelegate() const;

  /// \brief Set a delegate to access a third-party AI framework.
  ///
  /// \param[in] delegate The custom delegate.
  void set_delegate(const std::shared_ptr<AbstractDelegate> &delegate);

  // deprecated
  void SetDelegate(const std::shared_ptr<Delegate> &delegate);

  /// \brief Get the delegate of the third-party AI framework.
  ///
  /// \return Pointer to the custom delegate.
  std::shared_ptr<AbstractDelegate> get_delegate() const;

  // deprecated
  std::shared_ptr<Delegate> GetDelegate() const;

  /// \brief Set whether a quantized model should run as a float model on multiple devices.
  ///
  /// \param[in] float_mode true: run as a float model; false: do not run as a float model.
  void SetMultiModalHW(bool float_mode);

  /// \brief Get the model run mode.
  ///
  /// \return Bool value that indicates whether the model runs as a float model.
  bool GetMultiModalHW() const;

  /// \brief Get a mutable reference to the DeviceInfoContext vector in this context. Only MindSpore Lite supports
  /// heterogeneous scenarios with multiple members in the vector.
  ///
  /// \return Mutable reference to the DeviceInfoContext vector in this context.
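  ///
  /// A heterogeneous-scheduling sketch (illustrative; backends are tried in vector order):
  /// \code
  /// mindspore::Context context;
  /// auto &device_list = context.MutableDeviceInfo();
  /// device_list.push_back(std::make_shared<mindspore::GPUDeviceInfo>());  // preferred backend
  /// device_list.push_back(std::make_shared<mindspore::CPUDeviceInfo>());  // fallback backend
  /// \endcode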
  std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();

 private:
  std::shared_ptr<Data> data_;
};

/// \brief DeviceInfoContext defines different device contexts.
class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
 public:
  struct Data;

  DeviceInfoContext();
  virtual ~DeviceInfoContext() = default;

  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  virtual enum DeviceType GetDeviceType() const = 0;

  /// \brief Provides an RTTI-like conversion that works even when the -fno-rtti compilation option is enabled:
  /// converts this DeviceInfoContext to a shared pointer of type T, returning nullptr if the conversion fails.
  ///
  /// \return A pointer of type T after conversion, or nullptr if the conversion fails.
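  ///
  /// A minimal sketch (illustrative):
  /// \code
  /// std::shared_ptr<mindspore::DeviceInfoContext> info = std::make_shared<mindspore::CPUDeviceInfo>();
  /// auto cpu_info = info->Cast<mindspore::CPUDeviceInfo>();  // non-null: device types match
  /// auto gpu_info = info->Cast<mindspore::GPUDeviceInfo>();  // nullptr: device types differ
  /// \endcode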
  template <class T>
  std::shared_ptr<T> Cast() {
    static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
    if (GetDeviceType() != T().GetDeviceType()) {
      return nullptr;
    }

    return std::static_pointer_cast<T>(shared_from_this());
  }

  /// \brief Obtain the provider's name.
  ///
  /// \return The provider's name.
  inline std::string GetProvider() const;

  /// \brief Set the provider's name.
  ///
  /// \param[in] provider The provider's name.
  inline void SetProvider(const std::string &provider);

  /// \brief Obtain the provider's device type.
  ///
  /// \return The provider's device type.
  inline std::string GetProviderDevice() const;

  /// \brief Set the provider's device type.
  ///
  /// \param[in] device The provider's device type, e.g. CPU.
  inline void SetProviderDevice(const std::string &device);

  /// \brief Set the memory allocator.
  ///
  /// \param[in] allocator The memory allocator, which can be defined by the user.
  void SetAllocator(const std::shared_ptr<Allocator> &allocator);

  /// \brief Obtain the memory allocator.
  ///
  /// \return The memory allocator.
  std::shared_ptr<Allocator> GetAllocator() const;

 protected:
  std::vector<char> GetProviderChar() const;
  void SetProvider(const std::vector<char> &provider);
  std::vector<char> GetProviderDeviceChar() const;
  void SetProviderDevice(const std::vector<char> &device);

  std::shared_ptr<Data> data_;
};

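// The inline wrappers below convert std::string to/from std::vector<char> (see
// include/api/dual_abi_helper.h) so that no std::string crosses the library boundary,
// keeping the exported interface stable across libstdc++ dual-ABI (pre/post C++11) builds.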
std::string DeviceInfoContext::GetProvider() const { return CharToString(GetProviderChar()); }
void DeviceInfoContext::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }
std::string DeviceInfoContext::GetProviderDevice() const { return CharToString(GetProviderDeviceChar()); }
void DeviceInfoContext::SetProviderDevice(const std::string &device) { SetProviderDevice(StringToChar(device)); }

/// \brief Derived from DeviceInfoContext. The configuration for running the model automatically on the host devices,
/// including CPU/GPU/NPU/Ascend310/Ascend910. This option is only valid for MindSpore Lite.
class MS_API AutoDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAllDevice; }
};

/// \brief Derived from DeviceInfoContext. The configuration for running the model on the CPU. This option is only
/// valid for MindSpore Lite.
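///
/// A minimal sketch (illustrative):
/// \code
/// auto cpu_info = std::make_shared<mindspore::CPUDeviceInfo>();
/// cpu_info->SetEnableFP16(true);  // prefer float16 kernels where supported
/// \endcode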
class MS_API CPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; }

  /// \brief Set whether to enable float16 inference.
  ///
  /// \param[in] is_fp16 Enable float16 inference or not.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;
};

/// \brief Derived from DeviceInfoContext. The configuration for running the model on the Kirin NPU. This option is
/// only valid for MindSpore Lite.
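///
/// A minimal sketch (illustrative; frequency 3 is the documented default):
/// \code
/// auto npu_info = std::make_shared<mindspore::KirinNPUDeviceInfo>();
/// npu_info->SetEnableFP16(true);
/// npu_info->SetFrequency(3);  // high performance
/// \endcode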
class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; }

  /// \brief Set whether to enable float16 inference.
  ///
  /// \param[in] is_fp16 Enable float16 inference or not.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;

  /// \brief Set the NPU frequency.
  ///
  /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), or 4 (extreme
  /// performance). Defaults to 3.
  void SetFrequency(int frequency);

  /// \brief Get the NPU frequency.
  ///
  /// \return The NPU frequency.
  int GetFrequency() const;
};

/// \brief Derived from DeviceInfoContext. The configuration for running the model on the GPU.
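///
/// A minimal sketch (illustrative; device 0 and "fp16" are example values):
/// \code
/// auto gpu_info = std::make_shared<mindspore::GPUDeviceInfo>();
/// gpu_info->SetDeviceID(0);
/// gpu_info->SetPrecisionMode("fp16");
/// \endcode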
class MS_API GPUDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; }

  /// \brief Set the device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;

  /// \brief Get the distribution rank id.
  ///
  /// \return The rank id.
  int GetRankID() const;

  /// \brief Get the distribution group size.
  ///
  /// \return The group size.
  int GetGroupSize() const;

  /// \brief Set the precision mode.
  ///
  /// \param[in] precision_mode Optional "origin" or "fp16". "origin" is the default.
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// \brief Get the precision mode.
  ///
  /// \return The precision mode.
  inline std::string GetPrecisionMode() const;

  /// \brief Set whether to enable float16 inference.
  ///
  /// \param[in] is_fp16 Enable float16 inference or not.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;

  /// \brief Set whether to enable sharing memory with OpenGL.
  ///
  /// \param[in] is_enable_gl_texture Enable sharing OpenCL memory with OpenGL or not.
  void SetEnableGLTexture(bool is_enable_gl_texture);

  /// \brief Get whether sharing memory with OpenGL is enabled.
  ///
  /// \return Whether sharing memory with OpenGL is enabled.
  bool GetEnableGLTexture() const;

  /// \brief Set the current OpenGL context.
  ///
  /// \param[in] gl_context The current OpenGL context.
  void SetGLContext(void *gl_context);

  /// \brief Get the current OpenGL context.
  ///
  /// \return The OpenGL context in use.
  void *GetGLContext() const;

  /// \brief Set the current OpenGL display.
  ///
  /// \param[in] gl_display The current OpenGL display.
  void SetGLDisplay(void *gl_display);

  /// \brief Get the current OpenGL display.
  ///
  /// \return The OpenGL display in use.
  void *GetGLDisplay() const;

 private:
  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;
};

void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

/// \brief Derived from DeviceInfoContext. The configuration for running the model on Ascend. This option is
/// invalid for MindSpore Lite.
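///
/// A minimal sketch (illustrative; the input name and shape are example values):
/// \code
/// auto ascend_info = std::make_shared<mindspore::AscendDeviceInfo>();
/// ascend_info->SetDeviceID(0);
/// ascend_info->SetInputShape("input:1,3,224,224");
/// ascend_info->SetPrecisionMode("enforce_fp16");
/// \endcode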
class MS_API AscendDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; }

  /// \brief Set the device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(uint32_t device_id);

  /// \brief Get the device id.
  ///
  /// \return The device id.
  uint32_t GetDeviceID() const;

  /// \brief Set the distribution rank id.
  ///
  /// \param[in] rank_id The rank id.
  void SetRankID(uint32_t rank_id);

  /// \brief Get the distribution rank id.
  ///
  /// \return The rank id.
  uint32_t GetRankID() const;

  /// \brief Set the AIPP configuration file path.
  ///
  /// \param[in] cfg_path AIPP configuration file path.
  inline void SetInsertOpConfigPath(const std::string &cfg_path);

  /// \brief Get the AIPP configuration file path.
  ///
  /// \return AIPP configuration file path.
  inline std::string GetInsertOpConfigPath() const;

  /// \brief Set the format of model inputs.
  ///
  /// \param[in] format Optional "NCHW", "NHWC", or "ND".
  inline void SetInputFormat(const std::string &format);

  /// \brief Get the format of model inputs.
  ///
  /// \return The format of model inputs.
  inline std::string GetInputFormat() const;

  /// \brief Set the shape of model inputs.
  ///
  /// \param[in] shape e.g. "input_op_name1:1,2,3,4;input_op_name2:4,3,2,1".
  inline void SetInputShape(const std::string &shape);

  /// \brief Get the shape of model inputs.
  ///
  /// \return The shape of model inputs.
  inline std::string GetInputShape() const;

  /// \brief Set the shape of model inputs.
  ///
  /// \param[in] shape e.g. {{0, {1,2,3,4}}, {1, {4,3,2,1}}} means the first input has shape 1,2,3,4 and the second
  /// input has shape 4,3,2,1.
  void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);

  /// \brief Get the shape of model inputs.
  ///
  /// \return The shape of model inputs.
  std::map<int, std::vector<int>> GetInputShapeMap() const;

  /// \brief Set the dynamic batch sizes of model inputs. The number of configured sizes ranges from 2 to 100.
  ///
  /// \param[in] dynamic_batch_size e.g. {1, 2} means batch sizes 1 and 2 are configured.
  void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);

  /// \brief Get the dynamic batch sizes of model inputs.
  ///
  /// \return The dynamic batch sizes of model inputs in string format.
  inline std::string GetDynamicBatchSize() const;

  /// \brief Set the dynamic image size of model inputs.
  ///
  /// \param[in] dynamic_image_size Height/width pairs, e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
  inline void SetDynamicImageSize(const std::string &dynamic_image_size);

  /// \brief Get the dynamic image size of model inputs.
  ///
  /// \return The image size of model inputs.
  inline std::string GetDynamicImageSize() const;

  /// \brief Set the type of model outputs.
  ///
  /// \param[in] output_type FP32, UINT8, or FP16.
  void SetOutputType(enum DataType output_type);

  /// \brief Get the type of model outputs.
  ///
  /// \return The configured type of model outputs.
  enum DataType GetOutputType() const;

  /// \brief Set the precision mode of the model.
  ///
  /// \param[in] precision_mode Optional "enforce_fp16", "preferred_fp32", "enforce_origin", "enforce_fp32", or
  /// "preferred_optimal". "enforce_fp16" is the default.
  inline void SetPrecisionMode(const std::string &precision_mode);

  /// \brief Get the precision mode of the model.
  ///
  /// \return The precision mode.
  inline std::string GetPrecisionMode() const;

  /// \brief Set the op select implementation mode.
  ///
  /// \param[in] op_select_impl_mode Optional "high_performance" or "high_precision". "high_performance" is the
  /// default.
  inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);

  /// \brief Get the op select implementation mode.
  ///
  /// \return The op select implementation mode.
  inline std::string GetOpSelectImplMode() const;

  /// \brief Set the fusion switch config file path. The file controls which fusion passes are turned off.
  ///
  /// \param[in] cfg_path Fusion switch config file path.
  inline void SetFusionSwitchConfigPath(const std::string &cfg_path);

  /// \brief Get the fusion switch config file path.
  ///
  /// \return The fusion switch config file path.
  inline std::string GetFusionSwitchConfigPath() const;

  /// \brief Set the buffer optimize mode.
  ///
  /// \param[in] buffer_optimize_mode Optional "l1_optimize", "l2_optimize", "off_optimize", or "l1_and_l2_optimize".
  /// Defaults to "l2_optimize".
  inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);

  /// \brief Get the buffer optimize mode.
  ///
  /// \return The buffer optimize mode.
  inline std::string GetBufferOptimizeMode() const;

 private:
  void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetInsertOpConfigPathChar() const;

  void SetInputFormat(const std::vector<char> &format);
  std::vector<char> GetInputFormatChar() const;

  void SetInputShape(const std::vector<char> &shape);
  std::vector<char> GetInputShapeChar() const;

  std::vector<char> GetDynamicBatchSizeChar() const;

  void SetDynamicImageSize(const std::vector<char> &dynamic_image_size);
  std::vector<char> GetDynamicImageSizeChar() const;

  void SetPrecisionMode(const std::vector<char> &precision_mode);
  std::vector<char> GetPrecisionModeChar() const;

  void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  std::vector<char> GetOpSelectImplModeChar() const;

  void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  std::vector<char> GetFusionSwitchConfigPathChar() const;

  void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  std::vector<char> GetBufferOptimizeModeChar() const;
};

using Ascend310DeviceInfo = AscendDeviceInfo;
using Ascend910DeviceInfo = AscendDeviceInfo;

void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  SetInsertOpConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }

void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }

void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }

std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }

void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
  SetDynamicImageSize(StringToChar(dynamic_image_size));
}

std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }

void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  SetPrecisionMode(StringToChar(precision_mode));
}
std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }

void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  SetOpSelectImplMode(StringToChar(op_select_impl_mode));
}
std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }

void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  SetFusionSwitchConfigPath(StringToChar(cfg_path));
}
std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
  return CharToString(GetFusionSwitchConfigPathChar());
}

void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
}
std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }

/// \brief A named binary attribute used to pass backend-specific options to an NNRt device.
struct Extension {
  std::string name;
  std::vector<uint8_t> value;
};

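/// \brief Derived from DeviceInfoContext. The configuration for running the model through the OHOS Neural Network
/// Runtime (NNRt).
///
/// A minimal sketch (illustrative; the device id is an example value and "FooOption" is a hypothetical
/// extension name):
/// \code
/// auto nnrt_info = std::make_shared<mindspore::NNRTDeviceInfo>();
/// nnrt_info->SetDeviceID(0);
/// nnrt_info->SetEnableFP16(true);
/// nnrt_info->SetExtensions({{"FooOption", {1}}});  // "FooOption" is a hypothetical extension name
/// \endcode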
class MS_API NNRTDeviceInfo : public DeviceInfoContext {
 public:
  /// \brief Get the type of this DeviceInfoContext.
  ///
  /// \return Type of this DeviceInfoContext.
  enum DeviceType GetDeviceType() const override { return DeviceType::kNNRt; }

  /// \brief Set the device id.
  ///
  /// \param[in] device_id The device id.
  void SetDeviceID(size_t device_id);

  /// \brief Get the device id.
  ///
  /// \return The device id.
  size_t GetDeviceID() const;

  /// \brief Set the performance mode.
  ///
  /// \param[in] performance_mode The performance mode.
  void SetPerformanceMode(int performance_mode);

  /// \brief Get the performance mode.
  ///
  /// \return The performance mode.
  int GetPerformanceMode() const;

  /// \brief Set the priority.
  ///
  /// \param[in] priority The priority.
  void SetPriority(int priority);

  /// \brief Get the priority.
  ///
  /// \return The priority.
  int GetPriority() const;

  /// \brief Set whether to enable float16 inference.
  ///
  /// \param[in] is_fp16 Enable float16 inference or not.
  void SetEnableFP16(bool is_fp16);

  /// \brief Get whether float16 inference is enabled.
  ///
  /// \return Whether float16 inference is enabled.
  bool GetEnableFP16() const;

  /// \brief Set the extensions.
  ///
  /// \param[in] extensions The extension array.
  void SetExtensions(const std::vector<Extension> &extensions);

  /// \brief Get the extensions.
  ///
  /// \return The extension array.
  std::vector<Extension> GetExtensions() const;
};
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_CONTEXT_H