1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "include/c_api/context_c.h"

#include <algorithm>
#include <memory>
#include <utility>

#include "src/c_api/context_c.h"
#include "src/common/log_adapter.h"
19
20 // ================ Context ================
OH_AI_ContextCreate()21 OH_AI_ContextHandle OH_AI_ContextCreate() {
22 auto impl = new (std::nothrow) mindspore::ContextC;
23 if (impl == nullptr) {
24 MS_LOG(ERROR) << "memory allocation failed.";
25 return nullptr;
26 }
27 return static_cast<OH_AI_ContextHandle>(impl);
28 }
29
OH_AI_ContextDestroy(OH_AI_ContextHandle * context)30 void OH_AI_ContextDestroy(OH_AI_ContextHandle *context) {
31 if (context != nullptr && *context != nullptr) {
32 auto impl = static_cast<mindspore::ContextC *>(*context);
33 delete impl;
34 *context = nullptr;
35 }
36 }
37
OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context,int32_t thread_num)38 void OH_AI_ContextSetThreadNum(OH_AI_ContextHandle context, int32_t thread_num) {
39 if (context == nullptr) {
40 MS_LOG(ERROR) << "param is nullptr.";
41 return;
42 }
43 auto impl = static_cast<mindspore::ContextC *>(context);
44 impl->thread_num = thread_num;
45 }
46
OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context)47 int32_t OH_AI_ContextGetThreadNum(const OH_AI_ContextHandle context) {
48 if (context == nullptr) {
49 MS_LOG(ERROR) << "param is nullptr.";
50 return 0;
51 }
52 auto impl = static_cast<mindspore::ContextC *>(context);
53 return impl->thread_num;
54 }
55
OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context,int mode)56 void OH_AI_ContextSetThreadAffinityMode(OH_AI_ContextHandle context, int mode) {
57 if (context == nullptr) {
58 MS_LOG(ERROR) << "param is nullptr.";
59 return;
60 }
61 auto impl = static_cast<mindspore::ContextC *>(context);
62 impl->affinity_mode = mode;
63 return;
64 }
65
OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context)66 int OH_AI_ContextGetThreadAffinityMode(const OH_AI_ContextHandle context) {
67 if (context == nullptr) {
68 MS_LOG(ERROR) << "param is nullptr.";
69 return 0;
70 }
71 auto impl = static_cast<mindspore::ContextC *>(context);
72 return impl->affinity_mode;
73 }
74
OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context,const int32_t * core_list,size_t core_num)75 void OH_AI_ContextSetThreadAffinityCoreList(OH_AI_ContextHandle context, const int32_t *core_list, size_t core_num) {
76 if (context == nullptr || core_list == nullptr) {
77 MS_LOG(ERROR) << "param is nullptr.";
78 return;
79 }
80 const std::vector<int32_t> vec_core_list(core_list, core_list + core_num);
81 auto impl = static_cast<mindspore::ContextC *>(context);
82 impl->affinity_core_list = vec_core_list;
83 return;
84 }
85
OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle context,size_t * core_num)86 const int32_t *OH_AI_ContextGetThreadAffinityCoreList(const OH_AI_ContextHandle context, size_t *core_num) {
87 if (context == nullptr || core_num == nullptr) {
88 MS_LOG(ERROR) << "param is nullptr.";
89 return nullptr;
90 }
91 auto impl = static_cast<mindspore::ContextC *>(context);
92 *core_num = impl->affinity_core_list.size();
93 return impl->affinity_core_list.data();
94 }
95
OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context,bool is_parallel)96 void OH_AI_ContextSetEnableParallel(OH_AI_ContextHandle context, bool is_parallel) {
97 if (context == nullptr) {
98 MS_LOG(ERROR) << "param is nullptr.";
99 return;
100 }
101 auto impl = static_cast<mindspore::ContextC *>(context);
102 impl->enable_parallel = is_parallel;
103 }
104
OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context)105 bool OH_AI_ContextGetEnableParallel(const OH_AI_ContextHandle context) {
106 if (context == nullptr) {
107 MS_LOG(ERROR) << "param is nullptr.";
108 return false;
109 }
110 auto impl = static_cast<mindspore::ContextC *>(context);
111 return impl->enable_parallel;
112 }
113
OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context,OH_AI_DeviceInfoHandle device_info)114 void OH_AI_ContextAddDeviceInfo(OH_AI_ContextHandle context, OH_AI_DeviceInfoHandle device_info) {
115 if (context == nullptr || device_info == nullptr) {
116 MS_LOG(ERROR) << "param is nullptr.";
117 return;
118 }
119 auto impl = static_cast<mindspore::ContextC *>(context);
120 std::shared_ptr<mindspore::DeviceInfoC> device(static_cast<mindspore::DeviceInfoC *>(device_info));
121 impl->device_info_list.push_back(device);
122 }
123
124 // ================ DeviceInfo ================
OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type)125 OH_AI_DeviceInfoHandle OH_AI_DeviceInfoCreate(OH_AI_DeviceType device_type) {
126 mindspore::DeviceInfoC *impl = new (std::nothrow) mindspore::DeviceInfoC;
127 if (impl == nullptr) {
128 MS_LOG(ERROR) << "memory allocation failed.";
129 return nullptr;
130 }
131 impl->device_type = device_type;
132 return static_cast<OH_AI_DeviceInfoHandle>(impl);
133 }
134
OH_AI_DeviceInfoDestroy(OH_AI_DeviceInfoHandle * device_info)135 void OH_AI_DeviceInfoDestroy(OH_AI_DeviceInfoHandle *device_info) {
136 if (device_info != nullptr && *device_info != nullptr) {
137 auto impl = static_cast<mindspore::DeviceInfoC *>(*device_info);
138 delete impl;
139 *device_info = nullptr;
140 }
141 }
142
OH_AI_DeviceInfoSetProvider(OH_AI_DeviceInfoHandle device_info,const char * provider)143 void OH_AI_DeviceInfoSetProvider(OH_AI_DeviceInfoHandle device_info, const char *provider) {
144 if (device_info == nullptr) {
145 MS_LOG(ERROR) << "param is nullptr.";
146 return;
147 }
148 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
149 impl->provider = provider;
150 }
151
OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info)152 const char *OH_AI_DeviceInfoGetProvider(const OH_AI_DeviceInfoHandle device_info) {
153 if (device_info == nullptr) {
154 MS_LOG(ERROR) << "param is nullptr.";
155 return nullptr;
156 }
157 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
158 return impl->provider.c_str();
159 }
160
OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info,const char * device)161 void OH_AI_DeviceInfoSetProviderDevice(OH_AI_DeviceInfoHandle device_info, const char *device) {
162 if (device_info == nullptr) {
163 MS_LOG(ERROR) << "param is nullptr.";
164 return;
165 }
166 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
167 impl->provider_device = device;
168 }
169
OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle device_info)170 const char *OH_AI_DeviceInfoGetProviderDevice(const OH_AI_DeviceInfoHandle device_info) {
171 if (device_info == nullptr) {
172 MS_LOG(ERROR) << "param is nullptr.";
173 return nullptr;
174 }
175 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
176 return impl->provider_device.c_str();
177 }
178
OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle device_info)179 OH_AI_DeviceType OH_AI_DeviceInfoGetDeviceType(const OH_AI_DeviceInfoHandle device_info) {
180 if (device_info == nullptr) {
181 MS_LOG(ERROR) << "param is nullptr.";
182 return OH_AI_DEVICETYPE_INVALID;
183 }
184 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
185 return impl->device_type;
186 }
187
OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info,bool is_fp16)188 void OH_AI_DeviceInfoSetEnableFP16(OH_AI_DeviceInfoHandle device_info, bool is_fp16) {
189 if (device_info == nullptr) {
190 MS_LOG(ERROR) << "param is nullptr.";
191 return;
192 }
193 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
194 if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) {
195 impl->enable_fp16 = is_fp16;
196 } else {
197 MS_LOG(ERROR) << "Unsupported Feature.";
198 }
199 }
200
OH_AI_DeviceInfoGetEnableFP16(const OH_AI_DeviceInfoHandle device_info)201 bool OH_AI_DeviceInfoGetEnableFP16(const OH_AI_DeviceInfoHandle device_info) {
202 if (device_info == nullptr) {
203 MS_LOG(ERROR) << "param is nullptr.";
204 return false;
205 }
206 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
207 if (impl->device_type == OH_AI_DEVICETYPE_CPU || impl->device_type == OH_AI_DEVICETYPE_GPU) {
208 return impl->enable_fp16;
209 } else {
210 MS_LOG(ERROR) << "Unsupported Feature. device_type: " << impl->device_type;
211 return false;
212 }
213 }
214
OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info,int frequency)215 void OH_AI_DeviceInfoSetFrequency(OH_AI_DeviceInfoHandle device_info, int frequency) { // only for KirinNPU
216 if (device_info == nullptr) {
217 MS_LOG(ERROR) << "param is nullptr.";
218 return;
219 }
220 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
221 if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
222 impl->frequency = frequency;
223 } else {
224 MS_LOG(ERROR) << "Unsupported Feature.";
225 }
226 }
227
OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info)228 int OH_AI_DeviceInfoGetFrequency(const OH_AI_DeviceInfoHandle device_info) { // only for KirinNPU
229 if (device_info == nullptr) {
230 MS_LOG(ERROR) << "param is nullptr.";
231 return -1;
232 }
233 auto impl = static_cast<mindspore::DeviceInfoC *>(device_info);
234 if (impl->device_type == OH_AI_DEVICETYPE_KIRIN_NPU) {
235 return impl->frequency;
236 } else {
237 MS_LOG(ERROR) << "Unsupported Feature.";
238 return -1;
239 }
240 }
241