1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Contains all the entry points to the C Neural Networks API.
18 // We do basic validation of the operands and then call the class
19 // that implements the functionality.
20
21 #define LOG_TAG "NeuralNetworks"
22
23 #include "NeuralNetworks.h"
24
25 #include "Callbacks.h"
26 #include "CompilationBuilder.h"
27 #include "ExecutionBuilder.h"
28 #include "Manager.h"
29 #include "Memory.h"
30 #include "NeuralNetworksOEM.h"
31 #include "ModelBuilder.h"
32
#include <memory>
#include <new>
#include <vector>
35
36 // Make sure the constants defined in the header files have not changed values.
37 // IMPORTANT: When adding new values, update kNumberOfDataTypes or kNumberOfDataTypesOEM
38 // in Utils.h.
39 static_assert(ANEURALNETWORKS_FLOAT32 == 0, "ANEURALNETWORKS_FLOAT32 has changed");
40 static_assert(ANEURALNETWORKS_INT32 == 1, "ANEURALNETWORKS_INT32 has changed");
41 static_assert(ANEURALNETWORKS_UINT32 == 2, "ANEURALNETWORKS_UINT32 has changed");
42 static_assert(ANEURALNETWORKS_TENSOR_FLOAT32 == 3,
43 "ANEURALNETWORKS_TENSOR_FLOAT32 has changed");
44 static_assert(ANEURALNETWORKS_TENSOR_INT32 == 4, "ANEURALNETWORKS_TENSOR_INT32 has changed");
45 static_assert(ANEURALNETWORKS_TENSOR_QUANT8_ASYMM == 5,
46 "ANEURALNETWORKS_TENSOR_QUANT8_ASYMM has changed");
47 static_assert(ANEURALNETWORKS_OEM_SCALAR == 10000, "ANEURALNETWORKS_OEM_SCALAR has changed");
48 static_assert(ANEURALNETWORKS_TENSOR_OEM_BYTE == 10001,
49 "ANEURALNETWORKS_TENSOR_OEM_BYTE has changed");
50
51 // IMPORTANT: When adding new values, update kNumberOfOperationTypes or
52 // kNumberOfOperationTypesOEMin Utils.h.
53 static_assert(ANEURALNETWORKS_ADD == 0, "ANEURALNETWORKS_ADD has changed");
54 static_assert(ANEURALNETWORKS_AVERAGE_POOL_2D == 1,
55 "ANEURALNETWORKS_AVERAGE_POOL_2D has changed");
56 static_assert(ANEURALNETWORKS_CONCATENATION == 2, "ANEURALNETWORKS_CONCATENATION has changed");
57 static_assert(ANEURALNETWORKS_CONV_2D == 3, "ANEURALNETWORKS_CONV_2D has changed");
58 static_assert(ANEURALNETWORKS_DEPTHWISE_CONV_2D == 4,
59 "ANEURALNETWORKS_DEPTHWISE_CONV_2D has changed");
60 static_assert(ANEURALNETWORKS_DEPTH_TO_SPACE == 5,
61 "ANEURALNETWORKS_DEPTH_TO_SPACE has changed");
62 static_assert(ANEURALNETWORKS_DEQUANTIZE == 6, "ANEURALNETWORKS_DEQUANTIZE has changed");
63 static_assert(ANEURALNETWORKS_EMBEDDING_LOOKUP == 7,
64 "ANEURALNETWORKS_EMBEDDING_LOOKUP has changed");
65 static_assert(ANEURALNETWORKS_FLOOR == 8, "ANEURALNETWORKS_FLOOR has changed");
66 static_assert(ANEURALNETWORKS_FULLY_CONNECTED == 9,
67 "ANEURALNETWORKS_FULLY_CONNECTED has changed");
68 static_assert(ANEURALNETWORKS_HASHTABLE_LOOKUP == 10,
69 "ANEURALNETWORKS_HASHTABLE_LOOKUP has changed");
70 static_assert(ANEURALNETWORKS_L2_NORMALIZATION == 11,
71 "ANEURALNETWORKS_L2_NORMALIZATION has changed");
72 static_assert(ANEURALNETWORKS_L2_POOL_2D == 12, "ANEURALNETWORKS_L2_POOL has changed");
73 static_assert(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION == 13,
74 "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION has changed");
75 static_assert(ANEURALNETWORKS_LOGISTIC == 14, "ANEURALNETWORKS_LOGISTIC has changed");
76 static_assert(ANEURALNETWORKS_LSH_PROJECTION == 15,
77 "ANEURALNETWORKS_LSH_PROJECTION has changed");
78 static_assert(ANEURALNETWORKS_LSTM == 16, "ANEURALNETWORKS_LSTM has changed");
79 static_assert(ANEURALNETWORKS_MAX_POOL_2D == 17, "ANEURALNETWORKS_MAX_POOL has changed");
80 static_assert(ANEURALNETWORKS_MUL == 18, "ANEURALNETWORKS_MUL has changed");
81 static_assert(ANEURALNETWORKS_RELU == 19, "ANEURALNETWORKS_RELU has changed");
82 static_assert(ANEURALNETWORKS_RELU1 == 20, "ANEURALNETWORKS_RELU1 has changed");
83 static_assert(ANEURALNETWORKS_RELU6 == 21, "ANEURALNETWORKS_RELU6 has changed");
84 static_assert(ANEURALNETWORKS_RESHAPE == 22, "ANEURALNETWORKS_RESHAPE has changed");
85 static_assert(ANEURALNETWORKS_RESIZE_BILINEAR == 23,
86 "ANEURALNETWORKS_RESIZE_BILINEAR has changed");
87 static_assert(ANEURALNETWORKS_RNN == 24, "ANEURALNETWORKS_RNN has changed");
88 static_assert(ANEURALNETWORKS_SOFTMAX == 25, "ANEURALNETWORKS_SOFTMAX has changed");
89 static_assert(ANEURALNETWORKS_SPACE_TO_DEPTH == 26,
90 "ANEURALNETWORKS_SPACE_TO_DEPTH has changed");
91 static_assert(ANEURALNETWORKS_SVDF == 27, "ANEURALNETWORKS_SVDF has changed");
92 static_assert(ANEURALNETWORKS_TANH == 28, "ANEURALNETWORKS_TANH has changed");
93 static_assert(ANEURALNETWORKS_OEM_OPERATION == 10000,
94 "ANEURALNETWORKS_OEM_OPERATION has changed");
95
96 static_assert(ANEURALNETWORKS_FUSED_NONE == 0, "ANEURALNETWORKS_FUSED_NONE has changed");
97 static_assert(ANEURALNETWORKS_FUSED_RELU == 1, "ANEURALNETWORKS_FUSED_RELU has changed");
98 static_assert(ANEURALNETWORKS_FUSED_RELU1 == 2, "ANEURALNETWORKS_FUSED_RELU1 has changed");
99 static_assert(ANEURALNETWORKS_FUSED_RELU6 == 3, "ANEURALNETWORKS_FUSED_RELU6 has changed");
100
101 static_assert(ANEURALNETWORKS_PREFER_LOW_POWER == 0,
102 "ANEURALNETWORKS_PREFER_LOW_POWER has changed");
103 static_assert(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER == 1,
104 "ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER has changed");
105 static_assert(ANEURALNETWORKS_PREFER_SUSTAINED_SPEED == 2,
106 "ANEURALNETWORKS_PREFER_SUSTAINED_SPEED has changed");
107
108 static_assert(ANEURALNETWORKS_NO_ERROR == 0, "ANEURALNETWORKS_NO_ERROR has changed");
109 static_assert(ANEURALNETWORKS_OUT_OF_MEMORY == 1, "ANEURALNETWORKS_OUT_OF_MEMORY has changed");
110 static_assert(ANEURALNETWORKS_INCOMPLETE == 2, "ANEURALNETWORKS_INCOMPLETE has changed");
111 static_assert(ANEURALNETWORKS_UNEXPECTED_NULL == 3,
112 "ANEURALNETWORKS_UNEXPECTED_NULL has changed");
113 static_assert(ANEURALNETWORKS_BAD_DATA == 4, "ANEURALNETWORKS_BAD_DATA has changed");
114 static_assert(ANEURALNETWORKS_OP_FAILED == 5, "ANEURALNETWORKS_OP_FAILED has changed");
115 static_assert(ANEURALNETWORKS_BAD_STATE == 6, "ANEURALNETWORKS_BAD_STATE has changed");
116
117 static_assert(ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES == 128,
118 "ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES has changed");
119
120 // Make sure that the constants are compatible with the values defined in
121 // hardware/interfaces/neuralnetworks/1.0/types.hal.
122 static_assert(static_cast<int32_t>(OperandType::OEM) == ANEURALNETWORKS_OEM_SCALAR,
123 "OEM != ANEURALNETWORKS_OEM");
124 static_assert(static_cast<int32_t>(OperandType::FLOAT32) == ANEURALNETWORKS_FLOAT32,
125 "FLOAT32 != ANEURALNETWORKS_FLOAT32");
126 static_assert(static_cast<int32_t>(OperandType::INT32) == ANEURALNETWORKS_INT32,
127 "INT32 != ANEURALNETWORKS_INT32");
128 static_assert(static_cast<int32_t>(OperandType::UINT32) == ANEURALNETWORKS_UINT32,
129 "UINT32 != ANEURALNETWORKS_UINT32");
130 static_assert(static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) == ANEURALNETWORKS_TENSOR_OEM_BYTE,
131 "TENSOR_OEM_BYTE != ANEURALNETWORKS_TENSOR_OEM_BYTE");
132 static_assert(static_cast<int32_t>(OperandType::TENSOR_FLOAT32) == ANEURALNETWORKS_TENSOR_FLOAT32,
133 "TENSOR_FLOAT32 != ANEURALNETWORKS_TENSOR_FLOAT32");
134 static_assert(static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) ==
135 ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
136 "TENSOR_QUANT8_ASYMM != ANEURALNETWORKS_TENSOR_QUANT8_ASYMM");
137
138 static_assert(static_cast<int32_t>(OperationType::ADD) == ANEURALNETWORKS_ADD,
139 "OperationType::ADD != ANEURALNETWORKS_ADD");
140 static_assert(static_cast<int32_t>(OperationType::AVERAGE_POOL_2D) ==
141 ANEURALNETWORKS_AVERAGE_POOL_2D,
142 "OperationType::AVERAGE_POOL_2D != ANEURALNETWORKS_AVERAGE_POOL_2D");
143 static_assert(static_cast<int32_t>(OperationType::CONV_2D) == ANEURALNETWORKS_CONV_2D,
144 "OperationType::CONV_2D != ANEURALNETWORKS_CONV_2D");
145 static_assert(static_cast<int32_t>(OperationType::DEPTHWISE_CONV_2D) ==
146 ANEURALNETWORKS_DEPTHWISE_CONV_2D,
147 "OperationType::DEPTHWISE_CONV_2D != ANEURALNETWORKS_DEPTHWISE_CONV_2D");
148 static_assert(static_cast<int32_t>(OperationType::DEPTH_TO_SPACE) ==
149 ANEURALNETWORKS_DEPTH_TO_SPACE,
150 "OperationType::DEPTH_TO_SPACE != ANEURALNETWORKS_DEPTH_TO_SPACE");
151 static_assert(static_cast<int32_t>(OperationType::DEQUANTIZE) == ANEURALNETWORKS_DEQUANTIZE,
152 "OperationType::DEQUANTIZE != ANEURALNETWORKS_DEQUANTIZE");
153 static_assert(static_cast<int32_t>(OperationType::EMBEDDING_LOOKUP) ==
154 ANEURALNETWORKS_EMBEDDING_LOOKUP,
155 "OperationType::EMBEDDING_LOOKUP != ANEURALNETWORKS_EMBEDDING_LOOKUP");
156 static_assert(static_cast<int32_t>(OperationType::FLOOR) == ANEURALNETWORKS_FLOOR,
157 "OperationType::FLOOR != ANEURALNETWORKS_FLOOR");
158 static_assert(static_cast<int32_t>(OperationType::FULLY_CONNECTED) ==
159 ANEURALNETWORKS_FULLY_CONNECTED,
160 "OperationType::FULLY_CONNECTED != ANEURALNETWORKS_FULLY_CONNECTED");
161 static_assert(static_cast<int32_t>(OperationType::HASHTABLE_LOOKUP) ==
162 ANEURALNETWORKS_HASHTABLE_LOOKUP,
163 "OperationType::HASHTABLE_LOOKUP != ANEURALNETWORKS_HASHTABLE_LOOKUP");
164 static_assert(static_cast<int32_t>(OperationType::L2_NORMALIZATION) ==
165 ANEURALNETWORKS_L2_NORMALIZATION,
166 "OperationType::L2_NORMALIZATION != ANEURALNETWORKS_L2_NORMALIZATION");
167 static_assert(static_cast<int32_t>(OperationType::L2_POOL_2D) == ANEURALNETWORKS_L2_POOL_2D,
168 "OperationType::L2_POOL_2D != ANEURALNETWORKS_L2_POOL_2D");
169 static_assert(static_cast<int32_t>(OperationType::LOCAL_RESPONSE_NORMALIZATION) ==
170 ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION,
171 "OperationType::LOCAL_RESPONSE_NORMALIZATION != "
172 "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION");
173 static_assert(static_cast<int32_t>(OperationType::LOGISTIC) == ANEURALNETWORKS_LOGISTIC,
174 "OperationType::LOGISTIC != ANEURALNETWORKS_LOGISTIC");
175 static_assert(static_cast<int32_t>(OperationType::LSH_PROJECTION) ==
176 ANEURALNETWORKS_LSH_PROJECTION,
177 "OperationType::LSH_PROJECTION != ANEURALNETWORKS_LSH_PROJECTION");
178 static_assert(static_cast<int32_t>(OperationType::LSTM) == ANEURALNETWORKS_LSTM,
179 "OperationType::LSTM != ANEURALNETWORKS_LSTM");
180 static_assert(static_cast<int32_t>(OperationType::MAX_POOL_2D) == ANEURALNETWORKS_MAX_POOL_2D,
181 "OperationType::MAX_POOL_2D != ANEURALNETWORKS_MAX_POOL_2D");
182 static_assert(static_cast<int32_t>(OperationType::MUL) == ANEURALNETWORKS_MUL,
183 "OperationType::MUL != ANEURALNETWORKS_MUL");
184 static_assert(static_cast<int32_t>(OperationType::RELU) == ANEURALNETWORKS_RELU,
185 "OperationType::RELU != ANEURALNETWORKS_RELU");
186 static_assert(static_cast<int32_t>(OperationType::RELU1) == ANEURALNETWORKS_RELU1,
187 "OperationType::RELU1 != ANEURALNETWORKS_RELU1");
188 static_assert(static_cast<int32_t>(OperationType::RELU6) == ANEURALNETWORKS_RELU6,
189 "OperationType::RELU6 != ANEURALNETWORKS_RELU6");
190 static_assert(static_cast<int32_t>(OperationType::RESHAPE) == ANEURALNETWORKS_RESHAPE,
191 "OperationType::RESHAPE != ANEURALNETWORKS_RESHAPE");
192 static_assert(static_cast<int32_t>(OperationType::RESIZE_BILINEAR) ==
193 ANEURALNETWORKS_RESIZE_BILINEAR,
194 "OperationType::RESIZE_BILINEAR != ANEURALNETWORKS_RESIZE_BILINEAR");
195 static_assert(static_cast<int32_t>(OperationType::RNN) == ANEURALNETWORKS_RNN,
196 "OperationType::RNN != ANEURALNETWORKS_RNN");
197 static_assert(static_cast<int32_t>(OperationType::SOFTMAX) == ANEURALNETWORKS_SOFTMAX,
198 "OperationType::SOFTMAX != ANEURALNETWORKS_SOFTMAX");
199 static_assert(static_cast<int32_t>(OperationType::SPACE_TO_DEPTH) ==
200 ANEURALNETWORKS_SPACE_TO_DEPTH,
201 "OperationType::SPACE_TO_DEPTH != ANEURALNETWORKS_SPACE_TO_DEPTH");
202 static_assert(static_cast<int32_t>(OperationType::SVDF) == ANEURALNETWORKS_SVDF,
203 "OperationType::SVDF != ANEURALNETWORKS_SVDF");
204 static_assert(static_cast<int32_t>(OperationType::TANH) == ANEURALNETWORKS_TANH,
205 "OperationType::TANH != ANEURALNETWORKS_TANH");
206
207 static_assert(static_cast<int32_t>(FusedActivationFunc::NONE) == ANEURALNETWORKS_FUSED_NONE,
208 "FusedActivationFunc::NONE != ANEURALNETWORKS_FUSED_NONE");
209 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU) == ANEURALNETWORKS_FUSED_RELU,
210 "FusedActivationFunc::RELU != ANEURALNETWORKS_FUSED_RELU");
211 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU1) == ANEURALNETWORKS_FUSED_RELU1,
212 "FusedActivationFunc::RELU1 != ANEURALNETWORKS_FUSED_RELU1");
213 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU6) == ANEURALNETWORKS_FUSED_RELU6,
214 "FusedActivationFunc::RELU6 != ANEURALNETWORKS_FUSED_RELU6");
215
216 using android::sp;
217 using namespace android::nn;
218
ANeuralNetworksMemory_createFromFd(size_t size,int prot,int fd,size_t offset,ANeuralNetworksMemory ** memory)219 int ANeuralNetworksMemory_createFromFd(size_t size, int prot, int fd, size_t offset,
220 ANeuralNetworksMemory** memory) {
221 *memory = nullptr;
222 std::unique_ptr<MemoryFd> m = std::make_unique<MemoryFd>();
223 if (m == nullptr) {
224 return ANEURALNETWORKS_OUT_OF_MEMORY;
225 }
226 int n = m->set(size, prot, fd, offset);
227 if (n != ANEURALNETWORKS_NO_ERROR) {
228 return n;
229 }
230 *memory = reinterpret_cast<ANeuralNetworksMemory*>(m.release());
231 return ANEURALNETWORKS_NO_ERROR;
232 }
233
ANeuralNetworksMemory_free(ANeuralNetworksMemory * memory)234 void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) {
235 // No validation. Free of nullptr is valid.
236 Memory* m = reinterpret_cast<Memory*>(memory);
237 delete m;
238 }
239
ANeuralNetworksModel_create(ANeuralNetworksModel ** model)240 int ANeuralNetworksModel_create(ANeuralNetworksModel** model) {
241 initVLogMask();
242 if (!model) {
243 LOG(ERROR) << "ANeuralNetworksModel_create passed a nullptr";
244 return ANEURALNETWORKS_UNEXPECTED_NULL;
245 }
246 ModelBuilder* m = new ModelBuilder();
247 if (m == nullptr) {
248 *model = nullptr;
249 return ANEURALNETWORKS_OUT_OF_MEMORY;
250 }
251 *model = reinterpret_cast<ANeuralNetworksModel*>(m);
252 return ANEURALNETWORKS_NO_ERROR;
253 }
254
ANeuralNetworksModel_free(ANeuralNetworksModel * model)255 void ANeuralNetworksModel_free(ANeuralNetworksModel* model) {
256 // No validation. Free of nullptr is valid.
257 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
258 delete m;
259 }
260
ANeuralNetworksModel_finish(ANeuralNetworksModel * model)261 int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
262 if (!model) {
263 LOG(ERROR) << "ANeuralNetworksModel_finish passed a nullptr";
264 return ANEURALNETWORKS_UNEXPECTED_NULL;
265 }
266 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
267 return m->finish();
268 }
269
ANeuralNetworksModel_addOperand(ANeuralNetworksModel * model,const ANeuralNetworksOperandType * type)270 int ANeuralNetworksModel_addOperand(ANeuralNetworksModel* model,
271 const ANeuralNetworksOperandType* type) {
272 if (!model || !type) {
273 LOG(ERROR) << "ANeuralNetworksModel_addOperand passed a nullptr";
274 return ANEURALNETWORKS_UNEXPECTED_NULL;
275 }
276 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
277 return m->addOperand(*type);
278 }
279
ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel * model,int32_t index,const void * buffer,size_t length)280 int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, int32_t index,
281 const void* buffer, size_t length) {
282 if (!model || !buffer) {
283 LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
284 return ANEURALNETWORKS_UNEXPECTED_NULL;
285 }
286 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
287 return m->setOperandValue(index, buffer, length);
288 }
289
ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel * model,int32_t index,const ANeuralNetworksMemory * memory,size_t offset,size_t length)290 int ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel* model, int32_t index,
291 const ANeuralNetworksMemory* memory,
292 size_t offset, size_t length) {
293 if (!model || !memory) {
294 LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
295 return ANEURALNETWORKS_UNEXPECTED_NULL;
296 }
297 const Memory* mem = reinterpret_cast<const Memory*>(memory);
298 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
299 return m->setOperandValueFromMemory(index, mem, offset, length);
300 }
301
ANeuralNetworksModel_addOperation(ANeuralNetworksModel * model,ANeuralNetworksOperationType type,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)302 int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model,
303 ANeuralNetworksOperationType type, uint32_t inputCount,
304 const uint32_t* inputs, uint32_t outputCount,
305 const uint32_t* outputs) {
306 if (!model || !inputs || !outputs) {
307 LOG(ERROR) << "ANeuralNetworksModel_addOperation passed a nullptr";
308 return ANEURALNETWORKS_UNEXPECTED_NULL;
309 }
310 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
311 return m->addOperation(type, inputCount, inputs, outputCount, outputs);
312 }
313
ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel * model,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)314 int ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel* model, uint32_t inputCount,
315 const uint32_t* inputs, uint32_t outputCount,
316 const uint32_t* outputs) {
317 if (!model || !inputs || !outputs) {
318 LOG(ERROR) << ("ANeuralNetworksModel_identifyInputsAndOutputs passed a nullptr");
319 return ANEURALNETWORKS_UNEXPECTED_NULL;
320 }
321 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
322 return m->identifyInputsAndOutputs(inputCount, inputs, outputCount, outputs);
323 }
324
ANeuralNetworksCompilation_create(ANeuralNetworksModel * model,ANeuralNetworksCompilation ** compilation)325 int ANeuralNetworksCompilation_create(ANeuralNetworksModel* model,
326 ANeuralNetworksCompilation** compilation) {
327 if (!model || !compilation) {
328 LOG(ERROR) << "ANeuralNetworksCompilation_create passed a nullptr";
329 return ANEURALNETWORKS_UNEXPECTED_NULL;
330 }
331
332 ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
333 CompilationBuilder* c = nullptr;
334 int result = m->createCompilation(&c);
335 *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
336 return result;
337 }
338
ANeuralNetworksCompilation_free(ANeuralNetworksCompilation * compilation)339 void ANeuralNetworksCompilation_free(ANeuralNetworksCompilation* compilation) {
340 // No validation. Free of nullptr is valid.
341 // TODO specification says that a compilation-in-flight can be deleted
342 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
343 delete c;
344 }
345
ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation * compilation,int32_t preference)346 int ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation* compilation,
347 int32_t preference) {
348 if (!compilation) {
349 LOG(ERROR) << "ANeuralNetworksCompilation_setPreference passed a nullptr";
350 return ANEURALNETWORKS_UNEXPECTED_NULL;
351 }
352 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
353 return c->setPreference(preference);
354 }
355
ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation * compilation)356 int ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation* compilation) {
357 if (!compilation) {
358 LOG(ERROR) << "ANeuralNetworksCompilation_finish passed a nullptr";
359 return ANEURALNETWORKS_UNEXPECTED_NULL;
360 }
361 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
362 return c->finish();
363 }
364
ANeuralNetworksExecution_create(ANeuralNetworksCompilation * compilation,ANeuralNetworksExecution ** execution)365 int ANeuralNetworksExecution_create(ANeuralNetworksCompilation* compilation,
366 ANeuralNetworksExecution** execution) {
367 if (!compilation || !execution) {
368 LOG(ERROR) << "ANeuralNetworksExecution_create passed a nullptr";
369 return ANEURALNETWORKS_UNEXPECTED_NULL;
370 }
371
372 CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
373 ExecutionBuilder* r = nullptr;
374 int result = c->createExecution(&r);
375 *execution = reinterpret_cast<ANeuralNetworksExecution*>(r);
376 return result;
377 }
378
ANeuralNetworksExecution_free(ANeuralNetworksExecution * execution)379 void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) {
380 // TODO specification says that an execution-in-flight can be deleted
381 // No validation. Free of nullptr is valid.
382 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
383 delete r;
384 }
385
ANeuralNetworksExecution_setInput(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const void * buffer,size_t length)386 int ANeuralNetworksExecution_setInput(ANeuralNetworksExecution* execution, int32_t index,
387 const ANeuralNetworksOperandType* type, const void* buffer,
388 size_t length) {
389 // TODO: For a non-optional input, also verify that buffer is not null.
390 if (!execution) {
391 LOG(ERROR) << "ANeuralNetworksExecution_setInput passed a nullptr";
392 return ANEURALNETWORKS_UNEXPECTED_NULL;
393 }
394 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
395 return r->setInput(index, type, buffer, length);
396 }
397
ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const ANeuralNetworksMemory * memory,size_t offset,size_t length)398 int ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
399 const ANeuralNetworksOperandType* type,
400 const ANeuralNetworksMemory* memory, size_t offset,
401 size_t length) {
402 if (!execution || !memory) {
403 LOG(ERROR) << "ANeuralNetworksExecution_setInputFromMemory passed a nullptr";
404 return ANEURALNETWORKS_UNEXPECTED_NULL;
405 }
406
407 const Memory* m = reinterpret_cast<const Memory*>(memory);
408 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
409 return r->setInputFromMemory(index, type, m, offset, length);
410 }
411
ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,void * buffer,size_t length)412 int ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution* execution, int32_t index,
413 const ANeuralNetworksOperandType* type, void* buffer,
414 size_t length) {
415 if (!execution || !buffer) {
416 LOG(ERROR) << "ANeuralNetworksExecution_setOutput passed a nullptr";
417 return ANEURALNETWORKS_UNEXPECTED_NULL;
418 }
419 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
420 return r->setOutput(index, type, buffer, length);
421 }
422
ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const ANeuralNetworksMemory * memory,size_t offset,size_t length)423 int ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
424 const ANeuralNetworksOperandType* type,
425 const ANeuralNetworksMemory* memory, size_t offset,
426 size_t length) {
427 if (!execution || !memory) {
428 LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory passed a nullptr";
429 return ANEURALNETWORKS_UNEXPECTED_NULL;
430 }
431
432 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
433 const Memory* m = reinterpret_cast<const Memory*>(memory);
434 return r->setOutputFromMemory(index, type, m, offset, length);
435 }
436
ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution * execution,ANeuralNetworksEvent ** event)437 int ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution* execution,
438 ANeuralNetworksEvent** event) {
439 if (!execution || !event) {
440 LOG(ERROR) << "ANeuralNetworksExecution_startCompute passed a nullptr";
441 return ANEURALNETWORKS_UNEXPECTED_NULL;
442 }
443 // TODO validate the rest
444
445 ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
446
447 // Dynamically allocate an sp to wrap an ExecutionCallback, seen in the NN
448 // API as an abstract event object. The sp<ExecutionCallback> object is
449 // returned when the execution has been successfully launched, otherwise a
450 // nullptr is returned. The sp is used for ref-counting purposes. Without
451 // it, the HIDL service could attempt to communicate with a dead callback
452 // object.
453 std::unique_ptr<sp<ExecutionCallback>> e = std::make_unique<sp<ExecutionCallback>>();
454 *event = nullptr;
455
456 int n = r->startCompute(e.get());
457 if (n != ANEURALNETWORKS_NO_ERROR) {
458 return n;
459 }
460 *event = reinterpret_cast<ANeuralNetworksEvent*>(e.release());
461 return ANEURALNETWORKS_NO_ERROR;
462 }
463
ANeuralNetworksEvent_wait(ANeuralNetworksEvent * event)464 int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) {
465 if (event == nullptr) {
466 LOG(ERROR) << "ANeuralNetworksEvent_wait passed a nullptr";
467 return ANEURALNETWORKS_UNEXPECTED_NULL;
468 }
469
470 sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
471 (*e)->wait();
472 return ANEURALNETWORKS_NO_ERROR;
473 }
474
ANeuralNetworksEvent_free(ANeuralNetworksEvent * event)475 void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) {
476 // No validation. Free of nullptr is valid.
477 if (event) {
478 sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
479 (*e)->wait();
480 delete e;
481 }
482 }
483