• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/c_api/context_c.h"
17 #include "src/c_api/context_c.h"
18 #include "src/common/log_adapter.h"
19 
20 // ================ Context ================
OH_AI_ContextCreate()21 OH_AI_ContextHandle OH_AI_ContextCreate() {
22   auto impl = new (std::nothrow) mindspore::ContextC;
23   if (impl == nullptr) {
24     MS_LOG(ERROR) << "memory allocation failed.";
25     return nullptr;
26   }
27   return static_cast<OH_AI_ContextHandle>(impl);
28 }
29 
OH_AI_ContextDestroy(OH_AI_ContextHandle * context)30 void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) {
31   if (context != nullptr && *context != nullptr) {
32     auto impl = static_cast<mindspore::ContextC *>(*context);
33     delete impl;
34     *context = nullptr;
35   }
36 }
37 
OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context,int32_t thread_num)38 void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) {
39   if (context == nullptr) {
40     MS_LOG(ERROR) << "param is nullptr.";
41     return;
42   }
43   auto impl = static_cast<mindspore::ContextC *>(context);
44   impl->thread_num = thread_num;
45 }
46 
OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context)47 int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
48   if (context == nullptr) {
49     MS_LOG(ERROR) << "param is nullptr.";
50     return 0;
51   }
52   auto impl = static_cast<mindspore::ContextC *>(context);
53   return impl->thread_num;
54 }
55 
OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context,int mode)56 void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
57   if (context == nullptr) {
58     MS_LOG(ERROR) << "param is nullptr.";
59     return;
60   }
61   auto impl = static_cast<mindspore::ContextC *>(context);
62   impl->affinity_mode = mode;
63   return;
64 }
65 
OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context)66 int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) {
67   if (context == nullptr) {
68     MS_LOG(ERROR) << "param is nullptr.";
69     return 0;
70   }
71   auto impl = static_cast<mindspore::ContextC *>(context);
72   return impl->affinity_mode;
73 }
74 
OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context,const int32_t * core_list,size_t core_num)75 void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) {
76   if (context == nullptr || core_list == nullptr) {
77     MS_LOG(ERROR) << "param is nullptr.";
78     return;
79   }
80   const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
81   auto impl = static_cast<mindspore::ContextC *>(context);
82   impl->affinity_core_list = vec_core_list;
83   return;
84 }
85 
OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle context,size_t * core_num)86 const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle context, size_t *core_num) {
87   if (context == nullptr || core_num == nullptr) {
88     MS_LOG(ERROR) << "param is nullptr.";
89     return nullptr;
90   }
91   auto impl = static_cast<mindspore::ContextC *>(context);
92   *core_num = impl->affinity_core_list.size();
93   return impl->affinity_core_list.data();
94 }
95 
OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context,bool is_parallel)96 void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_parallel) {
97   if (context == nullptr) {
98     MS_LOG(ERROR) << "param is nullptr.";
99     return;
100   }
101   auto impl = static_cast<mindspore::ContextC *>(context);
102   impl->enable_parallel = is_parallel;
103 }
104 
OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context)105 bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
106   if (context == nullptr) {
107     MS_LOG(ERROR) << "param is nullptr.";
108     return false;
109   }
110   auto impl = static_cast<mindspore::ContextC *>(context);
111   return impl->enable_parallel;
112 }
113 
OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context,OH_AI_DeviceInfoHandle device_info)114 void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) {
115   if (context == nullptr || device_info == nullptr) {
116     MS_LOG(ERROR) << "param is nullptr.";
117     return;
118   }
119   auto impl = static_cast<mindspore::ContextC *>(context);
120   std::shared_ptr<mindspore::DeviceInfoC> device(static_cast<mindspore::DeviceInfoC *>(device_info));
121   impl->device_info_list.push_back(device);
122 }
123 
124 // ================ DeviceInfo ================
OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type)125 OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type) {
126   mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC;
127   if (impl == nullptr) {
128     MS_LOG(ERROR) << "memory allocation failed.";
129     return nullptr;
130   }
131   impl->device_type = device_type;
132   return static_cast<OH_AI_DeviceInfoHandle>(impl);
133 }
134 
OH_AI_DeviceInfoDestroy(OH_AI_DeviceInfoHandle * device_info)135 void OH_AI_DeviceInfoDestroy(OH_AI_DeviceInfoHandle *device_info) {
136   if (device_info != nullptr && *device_info != nullptr) {
137     auto impl = static_cast<mindspore::DeviceInfoC *>(*device_info);
138     delete impl;
139     *device_info = nullptr;
140   }
141 }
142 
OH_AI_DeviceInfoSetProvider(OH_AI_DeviceInfoHandle device_info,const char * provider)143 void OH_AI_DeviceInfoSetProvider(OH_AI_DeviceInfoHandle device_info, const char *provider) {
144   if (device_info == nullptr) {
145     MS_LOG(ERROR) << "param is nullptr.";
146     return;
147   }
148   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
149   impl->provider = provider;
150 }
151 
OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info)152 const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info) {
153   if (device_info == nullptr) {
154     MS_LOG(ERROR) << "param is nullptr.";
155     return nullptr;
156   }
157   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
158   return impl->provider.c_str();
159 }
160 
OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info,const char * device)161 void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const char *device) {
162   if (device_info == nullptr) {
163     MS_LOG(ERROR) << "param is nullptr.";
164     return;
165   }
166   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
167   impl->provider_device = device;
168 }
169 
OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle device_info)170 const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle device_info) {
171   if (device_info == nullptr) {
172     MS_LOG(ERROR) << "param is nullptr.";
173     return nullptr;
174   }
175   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
176   return impl->provider_device.c_str();
177 }
178 
OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle device_info)179 OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle device_info) {
180   if (device_info == nullptr) {
181     MS_LOG(ERROR) << "param is nullptr.";
182     return OH_AI_DEVICETYPE_INVALID;
183   }
184   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
185   return impl->device_type;
186 }
187 
OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info,bool is_fp16)188 void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_fp16) {
189   if (device_info == nullptr) {
190     MS_LOG(ERROR) << "param is nullptr.";
191     return;
192   }
193   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
194   if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) {
195     impl->enable_fp16 = is_fp16;
196   } else {
197     MS_LOG(ERROR) << "Unsupported Feature.";
198   }
199 }
200 
OH_AI_DeviceInfoGetEnableFP16(const OH_AI_DeviceInfoHandle device_info)201 bool OH_AI_DeviceInfoGetEnableFP16(const OH_AI_DeviceInfoHandle device_info) {
202   if (device_info == nullptr) {
203     MS_LOG(ERROR) << "param is nullptr.";
204     return false;
205   }
206   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
207   if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) {
208     return impl->enable_fp16;
209   } else {
210     MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type;
211     return false;
212   }
213 }
214 
OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info,int frequency)215 void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info, int frequency) {  // only for KirinNPU
216   if (device_info == nullptr) {
217     MS_LOG(ERROR) << "param is nullptr.";
218     return;
219   }
220   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
221   if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
222     impl->frequency = frequency;
223   } else {
224     MS_LOG(ERROR) << "Unsupported Feature.";
225   }
226 }
227 
OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info)228 int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info) {  // only for KirinNPU
229   if (device_info == nullptr) {
230     MS_LOG(ERROR) << "param is nullptr.";
231     return -1;
232   }
233   auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
234   if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
235     return impl->frequency;
236   } else {
237     MS_LOG(ERROR) << "Unsupported Feature.";
238     return -1;
239   }
240 }
241