• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Contains all the entry points to the C Neural Networks API.
18 // We do basic validation of the operands and then call the class
19 // that implements the functionality.
20 
21 #define LOG_TAG "NeuralNetworks"
22 
23 #include "NeuralNetworks.h"
24 
25 #include "Callbacks.h"
26 #include "CompilationBuilder.h"
27 #include "ExecutionBuilder.h"
28 #include "Manager.h"
29 #include "Memory.h"
30 #include "NeuralNetworksOEM.h"
31 #include "ModelBuilder.h"
32 
#include <memory>
#include <new>
#include <vector>
35 
36 // Make sure the constants defined in the header files have not changed values.
37 // IMPORTANT: When adding new values, update kNumberOfDataTypes or kNumberOfDataTypesOEM
38 // in Utils.h.
39 static_assert(ANEURALNETWORKS_FLOAT32 == 0, "ANEURALNETWORKS_FLOAT32 has changed");
40 static_assert(ANEURALNETWORKS_INT32 == 1, "ANEURALNETWORKS_INT32 has changed");
41 static_assert(ANEURALNETWORKS_UINT32 == 2, "ANEURALNETWORKS_UINT32 has changed");
42 static_assert(ANEURALNETWORKS_TENSOR_FLOAT32 == 3,
43               "ANEURALNETWORKS_TENSOR_FLOAT32 has changed");
44 static_assert(ANEURALNETWORKS_TENSOR_INT32 == 4, "ANEURALNETWORKS_TENSOR_INT32 has changed");
45 static_assert(ANEURALNETWORKS_TENSOR_QUANT8_ASYMM == 5,
46               "ANEURALNETWORKS_TENSOR_QUANT8_ASYMM has changed");
47 static_assert(ANEURALNETWORKS_OEM_SCALAR == 10000, "ANEURALNETWORKS_OEM_SCALAR has changed");
48 static_assert(ANEURALNETWORKS_TENSOR_OEM_BYTE == 10001,
49               "ANEURALNETWORKS_TENSOR_OEM_BYTE has changed");
50 
51 // IMPORTANT: When adding new values, update kNumberOfOperationTypes or
52 // kNumberOfOperationTypesOEMin Utils.h.
53 static_assert(ANEURALNETWORKS_ADD == 0, "ANEURALNETWORKS_ADD has changed");
54 static_assert(ANEURALNETWORKS_AVERAGE_POOL_2D == 1,
55               "ANEURALNETWORKS_AVERAGE_POOL_2D has changed");
56 static_assert(ANEURALNETWORKS_CONCATENATION == 2, "ANEURALNETWORKS_CONCATENATION has changed");
57 static_assert(ANEURALNETWORKS_CONV_2D == 3, "ANEURALNETWORKS_CONV_2D has changed");
58 static_assert(ANEURALNETWORKS_DEPTHWISE_CONV_2D == 4,
59               "ANEURALNETWORKS_DEPTHWISE_CONV_2D has changed");
60 static_assert(ANEURALNETWORKS_DEPTH_TO_SPACE == 5,
61               "ANEURALNETWORKS_DEPTH_TO_SPACE has changed");
62 static_assert(ANEURALNETWORKS_DEQUANTIZE == 6, "ANEURALNETWORKS_DEQUANTIZE has changed");
63 static_assert(ANEURALNETWORKS_EMBEDDING_LOOKUP == 7,
64               "ANEURALNETWORKS_EMBEDDING_LOOKUP has changed");
65 static_assert(ANEURALNETWORKS_FLOOR == 8, "ANEURALNETWORKS_FLOOR has changed");
66 static_assert(ANEURALNETWORKS_FULLY_CONNECTED == 9,
67               "ANEURALNETWORKS_FULLY_CONNECTED has changed");
68 static_assert(ANEURALNETWORKS_HASHTABLE_LOOKUP == 10,
69               "ANEURALNETWORKS_HASHTABLE_LOOKUP has changed");
70 static_assert(ANEURALNETWORKS_L2_NORMALIZATION == 11,
71               "ANEURALNETWORKS_L2_NORMALIZATION has changed");
72 static_assert(ANEURALNETWORKS_L2_POOL_2D == 12, "ANEURALNETWORKS_L2_POOL has changed");
73 static_assert(ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION == 13,
74               "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION has changed");
75 static_assert(ANEURALNETWORKS_LOGISTIC == 14, "ANEURALNETWORKS_LOGISTIC has changed");
76 static_assert(ANEURALNETWORKS_LSH_PROJECTION == 15,
77               "ANEURALNETWORKS_LSH_PROJECTION has changed");
78 static_assert(ANEURALNETWORKS_LSTM == 16, "ANEURALNETWORKS_LSTM has changed");
79 static_assert(ANEURALNETWORKS_MAX_POOL_2D == 17, "ANEURALNETWORKS_MAX_POOL has changed");
80 static_assert(ANEURALNETWORKS_MUL == 18, "ANEURALNETWORKS_MUL has changed");
81 static_assert(ANEURALNETWORKS_RELU == 19, "ANEURALNETWORKS_RELU has changed");
82 static_assert(ANEURALNETWORKS_RELU1 == 20, "ANEURALNETWORKS_RELU1 has changed");
83 static_assert(ANEURALNETWORKS_RELU6 == 21, "ANEURALNETWORKS_RELU6 has changed");
84 static_assert(ANEURALNETWORKS_RESHAPE == 22, "ANEURALNETWORKS_RESHAPE has changed");
85 static_assert(ANEURALNETWORKS_RESIZE_BILINEAR == 23,
86               "ANEURALNETWORKS_RESIZE_BILINEAR has changed");
87 static_assert(ANEURALNETWORKS_RNN == 24, "ANEURALNETWORKS_RNN has changed");
88 static_assert(ANEURALNETWORKS_SOFTMAX == 25, "ANEURALNETWORKS_SOFTMAX has changed");
89 static_assert(ANEURALNETWORKS_SPACE_TO_DEPTH == 26,
90               "ANEURALNETWORKS_SPACE_TO_DEPTH has changed");
91 static_assert(ANEURALNETWORKS_SVDF == 27, "ANEURALNETWORKS_SVDF has changed");
92 static_assert(ANEURALNETWORKS_TANH == 28, "ANEURALNETWORKS_TANH has changed");
93 static_assert(ANEURALNETWORKS_OEM_OPERATION == 10000,
94               "ANEURALNETWORKS_OEM_OPERATION has changed");
95 
96 static_assert(ANEURALNETWORKS_FUSED_NONE == 0, "ANEURALNETWORKS_FUSED_NONE has changed");
97 static_assert(ANEURALNETWORKS_FUSED_RELU == 1, "ANEURALNETWORKS_FUSED_RELU has changed");
98 static_assert(ANEURALNETWORKS_FUSED_RELU1 == 2, "ANEURALNETWORKS_FUSED_RELU1 has changed");
99 static_assert(ANEURALNETWORKS_FUSED_RELU6 == 3, "ANEURALNETWORKS_FUSED_RELU6 has changed");
100 
101 static_assert(ANEURALNETWORKS_PREFER_LOW_POWER == 0,
102               "ANEURALNETWORKS_PREFER_LOW_POWER has changed");
103 static_assert(ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER == 1,
104               "ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER has changed");
105 static_assert(ANEURALNETWORKS_PREFER_SUSTAINED_SPEED == 2,
106               "ANEURALNETWORKS_PREFER_SUSTAINED_SPEED has changed");
107 
108 static_assert(ANEURALNETWORKS_NO_ERROR == 0, "ANEURALNETWORKS_NO_ERROR has changed");
109 static_assert(ANEURALNETWORKS_OUT_OF_MEMORY == 1, "ANEURALNETWORKS_OUT_OF_MEMORY has changed");
110 static_assert(ANEURALNETWORKS_INCOMPLETE == 2, "ANEURALNETWORKS_INCOMPLETE has changed");
111 static_assert(ANEURALNETWORKS_UNEXPECTED_NULL == 3,
112               "ANEURALNETWORKS_UNEXPECTED_NULL has changed");
113 static_assert(ANEURALNETWORKS_BAD_DATA == 4, "ANEURALNETWORKS_BAD_DATA has changed");
114 static_assert(ANEURALNETWORKS_OP_FAILED == 5, "ANEURALNETWORKS_OP_FAILED has changed");
115 static_assert(ANEURALNETWORKS_BAD_STATE == 6, "ANEURALNETWORKS_BAD_STATE has changed");
116 
117 static_assert(ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES == 128,
118               "ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES has changed");
119 
120 // Make sure that the constants are compatible with the values defined in
121 // hardware/interfaces/neuralnetworks/1.0/types.hal.
122 static_assert(static_cast<int32_t>(OperandType::OEM) == ANEURALNETWORKS_OEM_SCALAR,
123               "OEM != ANEURALNETWORKS_OEM");
124 static_assert(static_cast<int32_t>(OperandType::FLOAT32) == ANEURALNETWORKS_FLOAT32,
125               "FLOAT32 != ANEURALNETWORKS_FLOAT32");
126 static_assert(static_cast<int32_t>(OperandType::INT32) == ANEURALNETWORKS_INT32,
127               "INT32 != ANEURALNETWORKS_INT32");
128 static_assert(static_cast<int32_t>(OperandType::UINT32) == ANEURALNETWORKS_UINT32,
129               "UINT32 != ANEURALNETWORKS_UINT32");
130 static_assert(static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) == ANEURALNETWORKS_TENSOR_OEM_BYTE,
131               "TENSOR_OEM_BYTE != ANEURALNETWORKS_TENSOR_OEM_BYTE");
132 static_assert(static_cast<int32_t>(OperandType::TENSOR_FLOAT32) == ANEURALNETWORKS_TENSOR_FLOAT32,
133               "TENSOR_FLOAT32 != ANEURALNETWORKS_TENSOR_FLOAT32");
134 static_assert(static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) ==
135                           ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
136               "TENSOR_QUANT8_ASYMM != ANEURALNETWORKS_TENSOR_QUANT8_ASYMM");
137 
138 static_assert(static_cast<int32_t>(OperationType::ADD) == ANEURALNETWORKS_ADD,
139               "OperationType::ADD != ANEURALNETWORKS_ADD");
140 static_assert(static_cast<int32_t>(OperationType::AVERAGE_POOL_2D) ==
141                           ANEURALNETWORKS_AVERAGE_POOL_2D,
142               "OperationType::AVERAGE_POOL_2D != ANEURALNETWORKS_AVERAGE_POOL_2D");
143 static_assert(static_cast<int32_t>(OperationType::CONV_2D) == ANEURALNETWORKS_CONV_2D,
144               "OperationType::CONV_2D != ANEURALNETWORKS_CONV_2D");
145 static_assert(static_cast<int32_t>(OperationType::DEPTHWISE_CONV_2D) ==
146                           ANEURALNETWORKS_DEPTHWISE_CONV_2D,
147               "OperationType::DEPTHWISE_CONV_2D != ANEURALNETWORKS_DEPTHWISE_CONV_2D");
148 static_assert(static_cast<int32_t>(OperationType::DEPTH_TO_SPACE) ==
149                           ANEURALNETWORKS_DEPTH_TO_SPACE,
150               "OperationType::DEPTH_TO_SPACE != ANEURALNETWORKS_DEPTH_TO_SPACE");
151 static_assert(static_cast<int32_t>(OperationType::DEQUANTIZE) == ANEURALNETWORKS_DEQUANTIZE,
152               "OperationType::DEQUANTIZE != ANEURALNETWORKS_DEQUANTIZE");
153 static_assert(static_cast<int32_t>(OperationType::EMBEDDING_LOOKUP) ==
154                           ANEURALNETWORKS_EMBEDDING_LOOKUP,
155               "OperationType::EMBEDDING_LOOKUP != ANEURALNETWORKS_EMBEDDING_LOOKUP");
156 static_assert(static_cast<int32_t>(OperationType::FLOOR) == ANEURALNETWORKS_FLOOR,
157               "OperationType::FLOOR != ANEURALNETWORKS_FLOOR");
158 static_assert(static_cast<int32_t>(OperationType::FULLY_CONNECTED) ==
159                           ANEURALNETWORKS_FULLY_CONNECTED,
160               "OperationType::FULLY_CONNECTED != ANEURALNETWORKS_FULLY_CONNECTED");
161 static_assert(static_cast<int32_t>(OperationType::HASHTABLE_LOOKUP) ==
162                           ANEURALNETWORKS_HASHTABLE_LOOKUP,
163               "OperationType::HASHTABLE_LOOKUP != ANEURALNETWORKS_HASHTABLE_LOOKUP");
164 static_assert(static_cast<int32_t>(OperationType::L2_NORMALIZATION) ==
165                           ANEURALNETWORKS_L2_NORMALIZATION,
166               "OperationType::L2_NORMALIZATION != ANEURALNETWORKS_L2_NORMALIZATION");
167 static_assert(static_cast<int32_t>(OperationType::L2_POOL_2D) == ANEURALNETWORKS_L2_POOL_2D,
168               "OperationType::L2_POOL_2D != ANEURALNETWORKS_L2_POOL_2D");
169 static_assert(static_cast<int32_t>(OperationType::LOCAL_RESPONSE_NORMALIZATION) ==
170                           ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION,
171               "OperationType::LOCAL_RESPONSE_NORMALIZATION != "
172               "ANEURALNETWORKS_LOCAL_RESPONSE_NORMALIZATION");
173 static_assert(static_cast<int32_t>(OperationType::LOGISTIC) == ANEURALNETWORKS_LOGISTIC,
174               "OperationType::LOGISTIC != ANEURALNETWORKS_LOGISTIC");
175 static_assert(static_cast<int32_t>(OperationType::LSH_PROJECTION) ==
176                           ANEURALNETWORKS_LSH_PROJECTION,
177               "OperationType::LSH_PROJECTION != ANEURALNETWORKS_LSH_PROJECTION");
178 static_assert(static_cast<int32_t>(OperationType::LSTM) == ANEURALNETWORKS_LSTM,
179               "OperationType::LSTM != ANEURALNETWORKS_LSTM");
180 static_assert(static_cast<int32_t>(OperationType::MAX_POOL_2D) == ANEURALNETWORKS_MAX_POOL_2D,
181               "OperationType::MAX_POOL_2D != ANEURALNETWORKS_MAX_POOL_2D");
182 static_assert(static_cast<int32_t>(OperationType::MUL) == ANEURALNETWORKS_MUL,
183               "OperationType::MUL != ANEURALNETWORKS_MUL");
184 static_assert(static_cast<int32_t>(OperationType::RELU) == ANEURALNETWORKS_RELU,
185               "OperationType::RELU != ANEURALNETWORKS_RELU");
186 static_assert(static_cast<int32_t>(OperationType::RELU1) == ANEURALNETWORKS_RELU1,
187               "OperationType::RELU1 != ANEURALNETWORKS_RELU1");
188 static_assert(static_cast<int32_t>(OperationType::RELU6) == ANEURALNETWORKS_RELU6,
189               "OperationType::RELU6 != ANEURALNETWORKS_RELU6");
190 static_assert(static_cast<int32_t>(OperationType::RESHAPE) == ANEURALNETWORKS_RESHAPE,
191               "OperationType::RESHAPE != ANEURALNETWORKS_RESHAPE");
192 static_assert(static_cast<int32_t>(OperationType::RESIZE_BILINEAR) ==
193                           ANEURALNETWORKS_RESIZE_BILINEAR,
194               "OperationType::RESIZE_BILINEAR != ANEURALNETWORKS_RESIZE_BILINEAR");
195 static_assert(static_cast<int32_t>(OperationType::RNN) == ANEURALNETWORKS_RNN,
196               "OperationType::RNN != ANEURALNETWORKS_RNN");
197 static_assert(static_cast<int32_t>(OperationType::SOFTMAX) == ANEURALNETWORKS_SOFTMAX,
198               "OperationType::SOFTMAX != ANEURALNETWORKS_SOFTMAX");
199 static_assert(static_cast<int32_t>(OperationType::SPACE_TO_DEPTH) ==
200                           ANEURALNETWORKS_SPACE_TO_DEPTH,
201               "OperationType::SPACE_TO_DEPTH != ANEURALNETWORKS_SPACE_TO_DEPTH");
202 static_assert(static_cast<int32_t>(OperationType::SVDF) == ANEURALNETWORKS_SVDF,
203               "OperationType::SVDF != ANEURALNETWORKS_SVDF");
204 static_assert(static_cast<int32_t>(OperationType::TANH) == ANEURALNETWORKS_TANH,
205               "OperationType::TANH != ANEURALNETWORKS_TANH");
206 
207 static_assert(static_cast<int32_t>(FusedActivationFunc::NONE) == ANEURALNETWORKS_FUSED_NONE,
208               "FusedActivationFunc::NONE != ANEURALNETWORKS_FUSED_NONE");
209 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU) == ANEURALNETWORKS_FUSED_RELU,
210               "FusedActivationFunc::RELU != ANEURALNETWORKS_FUSED_RELU");
211 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU1) == ANEURALNETWORKS_FUSED_RELU1,
212               "FusedActivationFunc::RELU1 != ANEURALNETWORKS_FUSED_RELU1");
213 static_assert(static_cast<int32_t>(FusedActivationFunc::RELU6) == ANEURALNETWORKS_FUSED_RELU6,
214               "FusedActivationFunc::RELU6 != ANEURALNETWORKS_FUSED_RELU6");
215 
216 using android::sp;
217 using namespace android::nn;
218 
ANeuralNetworksMemory_createFromFd(size_t size,int prot,int fd,size_t offset,ANeuralNetworksMemory ** memory)219 int ANeuralNetworksMemory_createFromFd(size_t size, int prot, int fd, size_t offset,
220                                        ANeuralNetworksMemory** memory) {
221     *memory = nullptr;
222     std::unique_ptr<MemoryFd> m = std::make_unique<MemoryFd>();
223     if (m == nullptr) {
224         return ANEURALNETWORKS_OUT_OF_MEMORY;
225     }
226     int n = m->set(size, prot, fd, offset);
227     if (n != ANEURALNETWORKS_NO_ERROR) {
228         return n;
229     }
230     *memory = reinterpret_cast<ANeuralNetworksMemory*>(m.release());
231     return ANEURALNETWORKS_NO_ERROR;
232 }
233 
ANeuralNetworksMemory_free(ANeuralNetworksMemory * memory)234 void ANeuralNetworksMemory_free(ANeuralNetworksMemory* memory) {
235     // No validation.  Free of nullptr is valid.
236     Memory* m = reinterpret_cast<Memory*>(memory);
237     delete m;
238 }
239 
ANeuralNetworksModel_create(ANeuralNetworksModel ** model)240 int ANeuralNetworksModel_create(ANeuralNetworksModel** model) {
241     initVLogMask();
242     if (!model) {
243         LOG(ERROR) << "ANeuralNetworksModel_create passed a nullptr";
244         return ANEURALNETWORKS_UNEXPECTED_NULL;
245     }
246     ModelBuilder* m = new ModelBuilder();
247     if (m == nullptr) {
248         *model = nullptr;
249         return ANEURALNETWORKS_OUT_OF_MEMORY;
250     }
251     *model = reinterpret_cast<ANeuralNetworksModel*>(m);
252     return ANEURALNETWORKS_NO_ERROR;
253 }
254 
ANeuralNetworksModel_free(ANeuralNetworksModel * model)255 void ANeuralNetworksModel_free(ANeuralNetworksModel* model) {
256     // No validation.  Free of nullptr is valid.
257     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
258     delete m;
259 }
260 
ANeuralNetworksModel_finish(ANeuralNetworksModel * model)261 int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
262     if (!model) {
263         LOG(ERROR) << "ANeuralNetworksModel_finish passed a nullptr";
264         return ANEURALNETWORKS_UNEXPECTED_NULL;
265     }
266     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
267     return m->finish();
268 }
269 
ANeuralNetworksModel_addOperand(ANeuralNetworksModel * model,const ANeuralNetworksOperandType * type)270 int ANeuralNetworksModel_addOperand(ANeuralNetworksModel* model,
271                                     const ANeuralNetworksOperandType* type) {
272     if (!model || !type) {
273         LOG(ERROR) << "ANeuralNetworksModel_addOperand passed a nullptr";
274         return ANEURALNETWORKS_UNEXPECTED_NULL;
275     }
276     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
277     return m->addOperand(*type);
278 }
279 
ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel * model,int32_t index,const void * buffer,size_t length)280 int ANeuralNetworksModel_setOperandValue(ANeuralNetworksModel* model, int32_t index,
281                                          const void* buffer, size_t length) {
282     if (!model || !buffer) {
283         LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
284         return ANEURALNETWORKS_UNEXPECTED_NULL;
285     }
286     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
287     return m->setOperandValue(index, buffer, length);
288 }
289 
ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel * model,int32_t index,const ANeuralNetworksMemory * memory,size_t offset,size_t length)290 int ANeuralNetworksModel_setOperandValueFromMemory(ANeuralNetworksModel* model, int32_t index,
291                                                    const ANeuralNetworksMemory* memory,
292                                                    size_t offset, size_t length) {
293     if (!model || !memory) {
294         LOG(ERROR) << "ANeuralNetworksModel_setOperandValue passed a nullptr";
295         return ANEURALNETWORKS_UNEXPECTED_NULL;
296     }
297     const Memory* mem = reinterpret_cast<const Memory*>(memory);
298     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
299     return m->setOperandValueFromMemory(index, mem, offset, length);
300 }
301 
ANeuralNetworksModel_addOperation(ANeuralNetworksModel * model,ANeuralNetworksOperationType type,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)302 int ANeuralNetworksModel_addOperation(ANeuralNetworksModel* model,
303                                       ANeuralNetworksOperationType type, uint32_t inputCount,
304                                       const uint32_t* inputs, uint32_t outputCount,
305                                       const uint32_t* outputs) {
306     if (!model || !inputs || !outputs) {
307         LOG(ERROR) << "ANeuralNetworksModel_addOperation passed a nullptr";
308         return ANEURALNETWORKS_UNEXPECTED_NULL;
309     }
310     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
311     return m->addOperation(type, inputCount, inputs, outputCount, outputs);
312 }
313 
ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel * model,uint32_t inputCount,const uint32_t * inputs,uint32_t outputCount,const uint32_t * outputs)314 int ANeuralNetworksModel_identifyInputsAndOutputs(ANeuralNetworksModel* model, uint32_t inputCount,
315                                                   const uint32_t* inputs, uint32_t outputCount,
316                                                   const uint32_t* outputs) {
317     if (!model || !inputs || !outputs) {
318         LOG(ERROR) << ("ANeuralNetworksModel_identifyInputsAndOutputs passed a nullptr");
319         return ANEURALNETWORKS_UNEXPECTED_NULL;
320     }
321     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
322     return m->identifyInputsAndOutputs(inputCount, inputs, outputCount, outputs);
323 }
324 
ANeuralNetworksCompilation_create(ANeuralNetworksModel * model,ANeuralNetworksCompilation ** compilation)325 int ANeuralNetworksCompilation_create(ANeuralNetworksModel* model,
326                                       ANeuralNetworksCompilation** compilation) {
327     if (!model || !compilation) {
328         LOG(ERROR) << "ANeuralNetworksCompilation_create passed a nullptr";
329         return ANEURALNETWORKS_UNEXPECTED_NULL;
330     }
331 
332     ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
333     CompilationBuilder* c = nullptr;
334     int result = m->createCompilation(&c);
335     *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(c);
336     return result;
337 }
338 
ANeuralNetworksCompilation_free(ANeuralNetworksCompilation * compilation)339 void ANeuralNetworksCompilation_free(ANeuralNetworksCompilation* compilation) {
340     // No validation.  Free of nullptr is valid.
341     // TODO specification says that a compilation-in-flight can be deleted
342     CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
343     delete c;
344 }
345 
ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation * compilation,int32_t preference)346 int ANeuralNetworksCompilation_setPreference(ANeuralNetworksCompilation* compilation,
347                                              int32_t preference) {
348     if (!compilation) {
349         LOG(ERROR) << "ANeuralNetworksCompilation_setPreference passed a nullptr";
350         return ANEURALNETWORKS_UNEXPECTED_NULL;
351     }
352     CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
353     return c->setPreference(preference);
354 }
355 
ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation * compilation)356 int ANeuralNetworksCompilation_finish(ANeuralNetworksCompilation* compilation) {
357     if (!compilation) {
358         LOG(ERROR) << "ANeuralNetworksCompilation_finish passed a nullptr";
359         return ANEURALNETWORKS_UNEXPECTED_NULL;
360     }
361     CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
362     return c->finish();
363 }
364 
ANeuralNetworksExecution_create(ANeuralNetworksCompilation * compilation,ANeuralNetworksExecution ** execution)365 int ANeuralNetworksExecution_create(ANeuralNetworksCompilation* compilation,
366                                     ANeuralNetworksExecution** execution) {
367     if (!compilation || !execution) {
368         LOG(ERROR) << "ANeuralNetworksExecution_create passed a nullptr";
369         return ANEURALNETWORKS_UNEXPECTED_NULL;
370     }
371 
372     CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation);
373     ExecutionBuilder* r = nullptr;
374     int result = c->createExecution(&r);
375     *execution = reinterpret_cast<ANeuralNetworksExecution*>(r);
376     return result;
377 }
378 
ANeuralNetworksExecution_free(ANeuralNetworksExecution * execution)379 void ANeuralNetworksExecution_free(ANeuralNetworksExecution* execution) {
380     // TODO specification says that an execution-in-flight can be deleted
381     // No validation.  Free of nullptr is valid.
382     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
383     delete r;
384 }
385 
ANeuralNetworksExecution_setInput(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const void * buffer,size_t length)386 int ANeuralNetworksExecution_setInput(ANeuralNetworksExecution* execution, int32_t index,
387                                       const ANeuralNetworksOperandType* type, const void* buffer,
388                                       size_t length) {
389     // TODO: For a non-optional input, also verify that buffer is not null.
390     if (!execution) {
391         LOG(ERROR) << "ANeuralNetworksExecution_setInput passed a nullptr";
392         return ANEURALNETWORKS_UNEXPECTED_NULL;
393     }
394     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
395     return r->setInput(index, type, buffer, length);
396 }
397 
ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const ANeuralNetworksMemory * memory,size_t offset,size_t length)398 int ANeuralNetworksExecution_setInputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
399                                                 const ANeuralNetworksOperandType* type,
400                                                 const ANeuralNetworksMemory* memory, size_t offset,
401                                                 size_t length) {
402     if (!execution || !memory) {
403         LOG(ERROR) << "ANeuralNetworksExecution_setInputFromMemory passed a nullptr";
404         return ANEURALNETWORKS_UNEXPECTED_NULL;
405     }
406 
407     const Memory* m = reinterpret_cast<const Memory*>(memory);
408     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
409     return r->setInputFromMemory(index, type, m, offset, length);
410 }
411 
ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,void * buffer,size_t length)412 int ANeuralNetworksExecution_setOutput(ANeuralNetworksExecution* execution, int32_t index,
413                                        const ANeuralNetworksOperandType* type, void* buffer,
414                                        size_t length) {
415     if (!execution || !buffer) {
416         LOG(ERROR) << "ANeuralNetworksExecution_setOutput passed a nullptr";
417         return ANEURALNETWORKS_UNEXPECTED_NULL;
418     }
419     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
420     return r->setOutput(index, type, buffer, length);
421 }
422 
ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution * execution,int32_t index,const ANeuralNetworksOperandType * type,const ANeuralNetworksMemory * memory,size_t offset,size_t length)423 int ANeuralNetworksExecution_setOutputFromMemory(ANeuralNetworksExecution* execution, int32_t index,
424                                                  const ANeuralNetworksOperandType* type,
425                                                  const ANeuralNetworksMemory* memory, size_t offset,
426                                                  size_t length) {
427     if (!execution || !memory) {
428         LOG(ERROR) << "ANeuralNetworksExecution_setOutputFromMemory passed a nullptr";
429         return ANEURALNETWORKS_UNEXPECTED_NULL;
430     }
431 
432     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
433     const Memory* m = reinterpret_cast<const Memory*>(memory);
434     return r->setOutputFromMemory(index, type, m, offset, length);
435 }
436 
ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution * execution,ANeuralNetworksEvent ** event)437 int ANeuralNetworksExecution_startCompute(ANeuralNetworksExecution* execution,
438                                           ANeuralNetworksEvent** event) {
439     if (!execution || !event) {
440         LOG(ERROR) << "ANeuralNetworksExecution_startCompute passed a nullptr";
441         return ANEURALNETWORKS_UNEXPECTED_NULL;
442     }
443     // TODO validate the rest
444 
445     ExecutionBuilder* r = reinterpret_cast<ExecutionBuilder*>(execution);
446 
447     // Dynamically allocate an sp to wrap an ExecutionCallback, seen in the NN
448     // API as an abstract event object. The sp<ExecutionCallback> object is
449     // returned when the execution has been successfully launched, otherwise a
450     // nullptr is returned. The sp is used for ref-counting purposes. Without
451     // it, the HIDL service could attempt to communicate with a dead callback
452     // object.
453     std::unique_ptr<sp<ExecutionCallback>> e = std::make_unique<sp<ExecutionCallback>>();
454     *event = nullptr;
455 
456     int n = r->startCompute(e.get());
457     if (n != ANEURALNETWORKS_NO_ERROR) {
458         return n;
459     }
460     *event = reinterpret_cast<ANeuralNetworksEvent*>(e.release());
461     return ANEURALNETWORKS_NO_ERROR;
462 }
463 
ANeuralNetworksEvent_wait(ANeuralNetworksEvent * event)464 int ANeuralNetworksEvent_wait(ANeuralNetworksEvent* event) {
465     if (event == nullptr) {
466         LOG(ERROR) << "ANeuralNetworksEvent_wait passed a nullptr";
467         return ANEURALNETWORKS_UNEXPECTED_NULL;
468     }
469 
470     sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
471     (*e)->wait();
472     return ANEURALNETWORKS_NO_ERROR;
473 }
474 
ANeuralNetworksEvent_free(ANeuralNetworksEvent * event)475 void ANeuralNetworksEvent_free(ANeuralNetworksEvent* event) {
476     // No validation.  Free of nullptr is valid.
477     if (event) {
478         sp<ExecutionCallback>* e = reinterpret_cast<sp<ExecutionCallback>*>(event);
479         (*e)->wait();
480         delete e;
481     }
482 }
483