• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Provides C++ classes to more easily use the Neural Networks API.
18 // TODO(b/117845862): this should be auto generated from NeuralNetworksWrapper.h.
19 
20 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_SL_SUPPORT_LIBRARY_WRAPPER_H
21 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_SL_SUPPORT_LIBRARY_WRAPPER_H
22 
23 #include <android-base/unique_fd.h>
24 #include <android/hardware_buffer.h>
25 #include <math.h>
26 #include <unistd.h>
27 
28 #include <algorithm>
29 #include <memory>
30 #include <optional>
31 #include <string>
32 #include <utility>
33 #include <vector>
34 
35 #include "NeuralNetworksWrapper.h"
36 #include "SupportLibrary.h"
37 
38 using namespace ::android::nn::wrapper;
39 
40 namespace android {
41 namespace nn {
42 namespace sl_wrapper {
43 
44 using ::android::nn::wrapper::Duration;
45 using ::android::nn::wrapper::OperandType;
46 using ::android::nn::wrapper::Result;
47 
48 class Memory {
49    public:
50     // Takes ownership of a ANeuralNetworksMemory
Memory(const NnApiSupportLibrary * nnapi,ANeuralNetworksMemory * memory)51     Memory(const NnApiSupportLibrary* nnapi, ANeuralNetworksMemory* memory)
52         : mNnApi(nnapi), mMemory(memory), mSize(0) {}
53 
54     // Create from a FD and may takes ownership of the fd.
55     Memory(const NnApiSupportLibrary* nnapi, size_t size, int protect, int fd, size_t offset,
56            bool ownsFd = false)
mNnApi(nnapi)57         : mNnApi(nnapi), mOwnedFd(ownsFd ? std::optional<int>{fd} : std::nullopt), mSize(size) {
58         mValid = mNnApi->ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
59                  ANEURALNETWORKS_NO_ERROR;
60     }
61 
62     // Create from a buffer, may take ownership.
Memory(const NnApiSupportLibrary * nnapi,AHardwareBuffer * buffer,bool ownAHWB,size_t size)63     Memory(const NnApiSupportLibrary* nnapi, AHardwareBuffer* buffer, bool ownAHWB, size_t size)
64         : mNnApi(nnapi), mOwnedAHWB(ownAHWB ? buffer : nullptr), mSize(size) {
65         mValid = mNnApi->ANeuralNetworksMemory_createFromAHardwareBuffer(buffer, &mMemory) ==
66                  ANEURALNETWORKS_NO_ERROR;
67     }
68 
69     // Create from a desc
Memory(const NnApiSupportLibrary * nnapi,ANeuralNetworksMemoryDesc * desc,size_t size)70     Memory(const NnApiSupportLibrary* nnapi, ANeuralNetworksMemoryDesc* desc, size_t size)
71         : mNnApi(nnapi), mSize(size) {
72         mValid = mNnApi->ANeuralNetworksMemory_createFromDesc(desc, &mMemory) ==
73                  ANEURALNETWORKS_NO_ERROR;
74     }
75 
~Memory()76     virtual ~Memory() {
77         if (mMemory) {
78             mNnApi->ANeuralNetworksMemory_free(mMemory);
79         }
80         if (mOwnedFd) {
81             close(*mOwnedFd);
82         }
83         if (mOwnedAHWB) {
84             AHardwareBuffer_release(mOwnedAHWB);
85         }
86     }
87 
88     // Disallow copy semantics to ensure the runtime object can only be freed
89     // once. Copy semantics could be enabled if some sort of reference counting
90     // or deep-copy system for runtime objects is added later.
91     Memory(const Memory&) = delete;
92     Memory& operator=(const Memory&) = delete;
93 
94     // Move semantics to remove access to the runtime object from the wrapper
95     // object that is being moved. This ensures the runtime object will be
96     // freed only once.
Memory(Memory && other)97     Memory(Memory&& other) { *this = std::move(other); }
98     Memory& operator=(Memory&& other) {
99         if (this != &other) {
100             if (mMemory) {
101                 mNnApi->ANeuralNetworksMemory_free(mMemory);
102             }
103             if (mOwnedFd) {
104                 close(*mOwnedFd);
105             }
106             if (mOwnedAHWB) {
107                 AHardwareBuffer_release(mOwnedAHWB);
108             }
109 
110             mMemory = other.mMemory;
111             mValid = other.mValid;
112             mNnApi = other.mNnApi;
113             mOwnedFd = other.mOwnedFd;
114             mOwnedAHWB = other.mOwnedAHWB;
115             other.mMemory = nullptr;
116             other.mValid = false;
117             other.mOwnedFd.reset();
118             other.mOwnedAHWB = nullptr;
119         }
120         return *this;
121     }
122 
get()123     ANeuralNetworksMemory* get() const { return mMemory; }
isValid()124     bool isValid() const { return mValid; }
getSize()125     size_t getSize() const { return mSize; }
copyTo(Memory & other)126     Result copyTo(Memory& other) {
127         return static_cast<Result>(mNnApi->ANeuralNetworksMemory_copy(mMemory, other.mMemory));
128     }
129 
130    private:
131     const NnApiSupportLibrary* mNnApi = nullptr;
132     ANeuralNetworksMemory* mMemory = nullptr;
133     bool mValid = true;
134     std::optional<int> mOwnedFd;
135     AHardwareBuffer* mOwnedAHWB = nullptr;
136     size_t mSize;
137 };
138 
139 class Model {
140    public:
Model(const NnApiSupportLibrary * nnapi)141     Model(const NnApiSupportLibrary* nnapi) : mNnApi(nnapi) {
142         mValid = mNnApi->ANeuralNetworksModel_create(&mModel) == ANEURALNETWORKS_NO_ERROR;
143     }
~Model()144     ~Model() {
145         if (mModel) {
146             mNnApi->ANeuralNetworksModel_free(mModel);
147         }
148     }
149 
150     // Disallow copy semantics to ensure the runtime object can only be freed
151     // once. Copy semantics could be enabled if some sort of reference counting
152     // or deep-copy system for runtime objects is added later.
153     Model(const Model&) = delete;
154     Model& operator=(const Model&) = delete;
155 
156     // Move semantics to remove access to the runtime object from the wrapper
157     // object that is being moved. This ensures the runtime object will be
158     // freed only once.
Model(Model && other)159     Model(Model&& other) { *this = std::move(other); }
160     Model& operator=(Model&& other) {
161         if (this != &other) {
162             if (mModel != nullptr) {
163                 mNnApi->ANeuralNetworksModel_free(mModel);
164             }
165             mNnApi = other.mNnApi;
166             mModel = other.mModel;
167             mNextOperandId = other.mNextOperandId;
168             mValid = other.mValid;
169             mRelaxed = other.mRelaxed;
170             mFinished = other.mFinished;
171             mOperands = std::move(other.mOperands);
172             mInputs = std::move(other.mInputs);
173             mOutputs = std::move(other.mOutputs);
174             other.mModel = nullptr;
175             other.mNextOperandId = 0;
176             other.mValid = false;
177             other.mRelaxed = false;
178             other.mFinished = false;
179         }
180         return *this;
181     }
182 
finish()183     Result finish() {
184         if (mValid) {
185             auto result = static_cast<Result>(mNnApi->ANeuralNetworksModel_finish(mModel));
186             if (result != Result::NO_ERROR) {
187                 mValid = false;
188             }
189             mFinished = true;
190             return result;
191         } else {
192             return Result::BAD_STATE;
193         }
194     }
195 
addOperand(const OperandType * type)196     uint32_t addOperand(const OperandType* type) {
197         if (mNnApi->ANeuralNetworksModel_addOperand(mModel, &type->operandType) !=
198             ANEURALNETWORKS_NO_ERROR) {
199             mValid = false;
200         } else {
201             mOperands.push_back(*type);
202         }
203 
204         if (type->channelQuant) {
205             if (mNnApi->ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
206                         mModel, mNextOperandId, &type->channelQuant.value().params) !=
207                 ANEURALNETWORKS_NO_ERROR) {
208                 mValid = false;
209             }
210         }
211 
212         return mNextOperandId++;
213     }
214 
215     template <typename T>
addConstantOperand(const OperandType * type,const T & value)216     uint32_t addConstantOperand(const OperandType* type, const T& value) {
217         static_assert(sizeof(T) <= ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES,
218                       "Values larger than ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES "
219                       "not supported");
220         uint32_t index = addOperand(type);
221         setOperandValue(index, &value);
222         return index;
223     }
224 
addModelOperand(const Model * value)225     uint32_t addModelOperand(const Model* value) {
226         OperandType operandType(Type::MODEL, {});
227         uint32_t operand = addOperand(&operandType);
228         setOperandValueFromModel(operand, value);
229         return operand;
230     }
231 
setOperandValue(uint32_t index,const void * buffer,size_t length)232     void setOperandValue(uint32_t index, const void* buffer, size_t length) {
233         if (mNnApi->ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
234             ANEURALNETWORKS_NO_ERROR) {
235             mValid = false;
236         }
237     }
238 
239     template <typename T>
setOperandValue(uint32_t index,const T * value)240     void setOperandValue(uint32_t index, const T* value) {
241         static_assert(!std::is_pointer<T>(), "No operand may have a pointer as its value");
242         return setOperandValue(index, value, sizeof(T));
243     }
244 
setOperandValueFromMemory(uint32_t index,const Memory * memory,uint32_t offset,size_t length)245     void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
246                                    size_t length) {
247         if (mNnApi->ANeuralNetworksModel_setOperandValueFromMemory(
248                     mModel, index, memory->get(), offset, length) != ANEURALNETWORKS_NO_ERROR) {
249             mValid = false;
250         }
251     }
252 
setOperandValueFromModel(uint32_t index,const Model * value)253     void setOperandValueFromModel(uint32_t index, const Model* value) {
254         if (mNnApi->ANeuralNetworksModel_setOperandValueFromModel(mModel, index, value->mModel) !=
255             ANEURALNETWORKS_NO_ERROR) {
256             mValid = false;
257         }
258     }
259 
setOperandValueFromModel(uint32_t index,ANeuralNetworksModel * value)260     void setOperandValueFromModel(uint32_t index, ANeuralNetworksModel* value) {
261         if (mNnApi->ANeuralNetworksModel_setOperandValueFromModel(mModel, index, value) !=
262             ANEURALNETWORKS_NO_ERROR) {
263             mValid = false;
264         }
265     }
266 
addOperation(ANeuralNetworksOperationType type,const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs)267     void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
268                       const std::vector<uint32_t>& outputs) {
269         if (mNnApi->ANeuralNetworksModel_addOperation(
270                     mModel, type, static_cast<uint32_t>(inputs.size()), inputs.data(),
271                     static_cast<uint32_t>(outputs.size()),
272                     outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
273             mValid = false;
274         }
275     }
identifyInputsAndOutputs(const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs)276     void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs,
277                                   const std::vector<uint32_t>& outputs) {
278         if (mNnApi->ANeuralNetworksModel_identifyInputsAndOutputs(
279                     mModel, static_cast<uint32_t>(inputs.size()), inputs.data(),
280                     static_cast<uint32_t>(outputs.size()),
281                     outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
282             mValid = false;
283         } else {
284             mInputs = inputs;
285             mOutputs = outputs;
286         }
287     }
288 
relaxComputationFloat32toFloat16(bool isRelax)289     void relaxComputationFloat32toFloat16(bool isRelax) {
290         if (mNnApi->ANeuralNetworksModel_relaxComputationFloat32toFloat16(mModel, isRelax) ==
291             ANEURALNETWORKS_NO_ERROR) {
292             mRelaxed = isRelax;
293         }
294     }
295 
getExtensionOperandType(const std::string & extensionName,uint16_t operandCodeWithinExtension,int32_t * type)296     void getExtensionOperandType(const std::string& extensionName,
297                                  uint16_t operandCodeWithinExtension, int32_t* type) {
298         if (mNnApi->ANeuralNetworksModel_getExtensionOperandType(
299                     mModel, extensionName.c_str(), operandCodeWithinExtension, type) !=
300             ANEURALNETWORKS_NO_ERROR) {
301             mValid = false;
302         }
303     }
304 
getExtensionOperationType(const std::string & extensionName,uint16_t operandCodeWithinExtension,ANeuralNetworksOperationType * type)305     void getExtensionOperationType(const std::string& extensionName,
306                                    uint16_t operandCodeWithinExtension,
307                                    ANeuralNetworksOperationType* type) {
308         if (mNnApi->ANeuralNetworksModel_getExtensionOperationType(
309                     mModel, extensionName.c_str(), operandCodeWithinExtension, type) !=
310             ANEURALNETWORKS_NO_ERROR) {
311             mValid = false;
312         }
313     }
314 
setOperandExtensionData(int32_t operandId,const void * data,size_t length)315     void setOperandExtensionData(int32_t operandId, const void* data, size_t length) {
316         if (mNnApi->ANeuralNetworksModel_setOperandExtensionData(mModel, operandId, data, length) !=
317             ANEURALNETWORKS_NO_ERROR) {
318             mValid = false;
319         }
320     }
321 
getHandle()322     ANeuralNetworksModel* getHandle() const { return mModel; }
isValid()323     bool isValid() const { return mValid; }
isRelaxed()324     bool isRelaxed() const { return mRelaxed; }
isFinished()325     bool isFinished() const { return mFinished; }
326 
getInputs()327     const std::vector<uint32_t>& getInputs() const { return mInputs; }
getOutputs()328     const std::vector<uint32_t>& getOutputs() const { return mOutputs; }
getOperands()329     const std::vector<OperandType>& getOperands() const { return mOperands; }
330 
331    protected:
332     const NnApiSupportLibrary* mNnApi = nullptr;
333     ANeuralNetworksModel* mModel = nullptr;
334     // We keep track of the operand ID as a convenience to the caller.
335     uint32_t mNextOperandId = 0;
336     // We keep track of the operand datatypes/dimensions as a convenience to the caller.
337     std::vector<OperandType> mOperands;
338     std::vector<uint32_t> mInputs;
339     std::vector<uint32_t> mOutputs;
340     bool mValid = true;
341     bool mRelaxed = false;
342     bool mFinished = false;
343 };
344 
345 class Compilation {
346    public:
347     // On success, createForDevice(s) will return Result::NO_ERROR and the created compilation;
348     // otherwise, it will return the error code and Compilation object wrapping a nullptr handle.
createForDevice(const NnApiSupportLibrary * nnapi,const Model * model,const ANeuralNetworksDevice * device)349     static std::pair<Result, Compilation> createForDevice(const NnApiSupportLibrary* nnapi,
350                                                           const Model* model,
351                                                           const ANeuralNetworksDevice* device) {
352         return createForDevices(nnapi, model, {device});
353     }
createForDevices(const NnApiSupportLibrary * nnapi,const Model * model,const std::vector<const ANeuralNetworksDevice * > & devices)354     static std::pair<Result, Compilation> createForDevices(
355             const NnApiSupportLibrary* nnapi, const Model* model,
356             const std::vector<const ANeuralNetworksDevice*>& devices) {
357         ANeuralNetworksCompilation* compilation = nullptr;
358         const Result result =
359                 static_cast<Result>(nnapi->ANeuralNetworksCompilation_createForDevices(
360                         model->getHandle(), devices.empty() ? nullptr : devices.data(),
361                         devices.size(), &compilation));
362         return {result, Compilation(nnapi, compilation)};
363     }
364 
~Compilation()365     ~Compilation() { mNnApi->ANeuralNetworksCompilation_free(mCompilation); }
366 
367     // Disallow copy semantics to ensure the runtime object can only be freed
368     // once. Copy semantics could be enabled if some sort of reference counting
369     // or deep-copy system for runtime objects is added later.
370     Compilation(const Compilation&) = delete;
371     Compilation& operator=(const Compilation&) = delete;
372 
373     // Move semantics to remove access to the runtime object from the wrapper
374     // object that is being moved. This ensures the runtime object will be
375     // freed only once.
Compilation(Compilation && other)376     Compilation(Compilation&& other) { *this = std::move(other); }
377     Compilation& operator=(Compilation&& other) {
378         if (this != &other) {
379             mNnApi = other.mNnApi;
380             mNnApi->ANeuralNetworksCompilation_free(mCompilation);
381             mCompilation = other.mCompilation;
382             other.mCompilation = nullptr;
383         }
384         return *this;
385     }
386 
setPreference(ExecutePreference preference)387     Result setPreference(ExecutePreference preference) {
388         return static_cast<Result>(mNnApi->ANeuralNetworksCompilation_setPreference(
389                 mCompilation, static_cast<int32_t>(preference)));
390     }
391 
setPriority(ExecutePriority priority)392     Result setPriority(ExecutePriority priority) {
393         return static_cast<Result>(mNnApi->ANeuralNetworksCompilation_setPriority(
394                 mCompilation, static_cast<int32_t>(priority)));
395     }
396 
setTimeout(uint64_t durationNs)397     Result setTimeout(uint64_t durationNs) {
398         return static_cast<Result>(
399                 mNnApi->ANeuralNetworksCompilation_setTimeout(mCompilation, durationNs));
400     }
401 
setCaching(const std::string & cacheDir,const std::vector<uint8_t> & token)402     Result setCaching(const std::string& cacheDir, const std::vector<uint8_t>& token) {
403         if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN) {
404             return Result::BAD_DATA;
405         }
406         return static_cast<Result>(mNnApi->ANeuralNetworksCompilation_setCaching(
407                 mCompilation, cacheDir.c_str(), token.data()));
408     }
409 
setCachingFromFds(const std::vector<int> & modelCacheFds,const std::vector<int> & dataCacheFds,const std::vector<uint8_t> & token)410     Result setCachingFromFds(const std::vector<int>& modelCacheFds,
411                              const std::vector<int>& dataCacheFds,
412                              const std::vector<uint8_t>& token) {
413         if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN) {
414             return Result::BAD_DATA;
415         }
416         return static_cast<Result>(mNnApi->SL_ANeuralNetworksCompilation_setCachingFromFds(
417                 mCompilation, modelCacheFds.data(), modelCacheFds.size(), dataCacheFds.data(),
418                 dataCacheFds.size(), token.data()));
419     }
420 
setCachingFromFds(const std::vector<base::unique_fd> & modelCacheOwnedFds,const std::vector<base::unique_fd> & dataCacheOwnedFds,const std::vector<uint8_t> & token)421     Result setCachingFromFds(const std::vector<base::unique_fd>& modelCacheOwnedFds,
422                              const std::vector<base::unique_fd>& dataCacheOwnedFds,
423                              const std::vector<uint8_t>& token) {
424         std::vector<int> modelCacheFds, dataCacheFds;
425         for (const auto& fd : modelCacheOwnedFds) {
426             modelCacheFds.push_back(fd.get());
427         }
428         for (const auto& fd : dataCacheOwnedFds) {
429             dataCacheFds.push_back(fd.get());
430         }
431         return setCachingFromFds(modelCacheFds, dataCacheFds, token);
432     }
433 
finish()434     Result finish() {
435         return static_cast<Result>(mNnApi->ANeuralNetworksCompilation_finish(mCompilation));
436     }
437 
getPreferredMemoryAlignmentForInput(uint32_t index,uint32_t * alignment)438     Result getPreferredMemoryAlignmentForInput(uint32_t index, uint32_t* alignment) const {
439         return static_cast<Result>(
440                 mNnApi->ANeuralNetworksCompilation_getPreferredMemoryAlignmentForInput(
441                         mCompilation, index, alignment));
442     };
443 
getPreferredMemoryPaddingForInput(uint32_t index,uint32_t * padding)444     Result getPreferredMemoryPaddingForInput(uint32_t index, uint32_t* padding) const {
445         return static_cast<Result>(
446                 mNnApi->ANeuralNetworksCompilation_getPreferredMemoryPaddingForInput(
447                         mCompilation, index, padding));
448     };
449 
getPreferredMemoryAlignmentForOutput(uint32_t index,uint32_t * alignment)450     Result getPreferredMemoryAlignmentForOutput(uint32_t index, uint32_t* alignment) const {
451         return static_cast<Result>(
452                 mNnApi->ANeuralNetworksCompilation_getPreferredMemoryAlignmentForOutput(
453                         mCompilation, index, alignment));
454     };
455 
getPreferredMemoryPaddingForOutput(uint32_t index,uint32_t * padding)456     Result getPreferredMemoryPaddingForOutput(uint32_t index, uint32_t* padding) const {
457         return static_cast<Result>(
458                 mNnApi->ANeuralNetworksCompilation_getPreferredMemoryPaddingForOutput(
459                         mCompilation, index, padding));
460     };
461 
getHandle()462     ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
463 
464    protected:
465     // Takes the ownership of ANeuralNetworksCompilation.
Compilation(const NnApiSupportLibrary * nnapi,ANeuralNetworksCompilation * compilation)466     Compilation(const NnApiSupportLibrary* nnapi, ANeuralNetworksCompilation* compilation)
467         : mNnApi(nnapi), mCompilation(compilation) {}
468 
469     const NnApiSupportLibrary* mNnApi = nullptr;
470     ANeuralNetworksCompilation* mCompilation = nullptr;
471 };
472 
473 class Execution {
474    public:
Execution(const NnApiSupportLibrary * nnapi,const Compilation * compilation)475     Execution(const NnApiSupportLibrary* nnapi, const Compilation* compilation)
476         : mNnApi(nnapi), mCompilation(compilation->getHandle()) {
477         int result = mNnApi->ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
478         if (result != 0) {
479             // TODO Handle the error
480         }
481     }
482 
~Execution()483     ~Execution() {
484         if (mExecution) {
485             mNnApi->ANeuralNetworksExecution_free(mExecution);
486         }
487     }
488 
489     // Disallow copy semantics to ensure the runtime object can only be freed
490     // once. Copy semantics could be enabled if some sort of reference counting
491     // or deep-copy system for runtime objects is added later.
492     Execution(const Execution&) = delete;
493     Execution& operator=(const Execution&) = delete;
494 
495     // Move semantics to remove access to the runtime object from the wrapper
496     // object that is being moved. This ensures the runtime object will be
497     // freed only once.
Execution(Execution && other)498     Execution(Execution&& other) { *this = std::move(other); }
499     Execution& operator=(Execution&& other) {
500         if (this != &other) {
501             if (mExecution != nullptr) {
502                 mNnApi->ANeuralNetworksExecution_free(mExecution);
503             }
504             mNnApi = other.mNnApi;
505             mCompilation = other.mCompilation;
506             mExecution = other.mExecution;
507             other.mCompilation = nullptr;
508             other.mExecution = nullptr;
509         }
510         return *this;
511     }
512 
513     Result setInput(uint32_t index, const void* buffer, size_t length,
514                     const ANeuralNetworksOperandType* type = nullptr) {
515         return static_cast<Result>(
516                 mNnApi->ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
517     }
518 
519     template <typename T>
520     Result setInput(uint32_t index, const T* value,
521                     const ANeuralNetworksOperandType* type = nullptr) {
522         static_assert(!std::is_pointer<T>(), "No operand may have a pointer as its value");
523         return setInput(index, value, sizeof(T), type);
524     }
525 
526     Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
527                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
528         return static_cast<Result>(mNnApi->ANeuralNetworksExecution_setInputFromMemory(
529                 mExecution, index, type, memory->get(), offset, length));
530     }
531 
532     Result setOutput(uint32_t index, void* buffer, size_t length,
533                      const ANeuralNetworksOperandType* type = nullptr) {
534         return static_cast<Result>(mNnApi->ANeuralNetworksExecution_setOutput(
535                 mExecution, index, type, buffer, length));
536     }
537 
538     template <typename T>
539     Result setOutput(uint32_t index, T* value, const ANeuralNetworksOperandType* type = nullptr) {
540         static_assert(!std::is_pointer<T>(), "No operand may have a pointer as its value");
541         return setOutput(index, value, sizeof(T), type);
542     }
543 
544     Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
545                                uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
546         return static_cast<Result>(mNnApi->ANeuralNetworksExecution_setOutputFromMemory(
547                 mExecution, index, type, memory->get(), offset, length));
548     }
549 
setLoopTimeout(uint64_t duration)550     Result setLoopTimeout(uint64_t duration) {
551         return static_cast<Result>(
552                 mNnApi->ANeuralNetworksExecution_setLoopTimeout(mExecution, duration));
553     }
554 
setMeasureTiming(bool measure)555     Result setMeasureTiming(bool measure) {
556         return static_cast<Result>(
557                 mNnApi->ANeuralNetworksExecution_setMeasureTiming(mExecution, measure));
558     }
559 
setTimeout(uint64_t duration)560     Result setTimeout(uint64_t duration) {
561         return static_cast<Result>(
562                 mNnApi->ANeuralNetworksExecution_setTimeout(mExecution, duration));
563     }
564 
getDuration(Duration durationCode,uint64_t * duration)565     Result getDuration(Duration durationCode, uint64_t* duration) {
566         return static_cast<Result>(mNnApi->ANeuralNetworksExecution_getDuration(
567                 mExecution, static_cast<int32_t>(durationCode), duration));
568     }
569 
enableInputAndOutputPadding(bool enable)570     Result enableInputAndOutputPadding(bool enable) {
571         return static_cast<Result>(
572                 mNnApi->ANeuralNetworksExecution_enableInputAndOutputPadding(mExecution, enable));
573     }
574 
setReusable(bool reusable)575     Result setReusable(bool reusable) {
576         return static_cast<Result>(
577                 mNnApi->ANeuralNetworksExecution_setReusable(mExecution, reusable));
578     }
579 
580     // By default, compute() uses the synchronous API. Either an argument or
581     // setComputeMode() can be used to change the behavior of compute() to
582     // use the burst API
583     // Returns the previous ComputeMode.
584     enum class ComputeMode { SYNC, BURST, FENCED };
setComputeMode(ComputeMode mode)585     static ComputeMode setComputeMode(ComputeMode mode) {
586         ComputeMode oldComputeMode = mComputeMode;
587         mComputeMode = mode;
588         return oldComputeMode;
589     }
getComputeMode()590     static ComputeMode getComputeMode() { return mComputeMode; }
591 
592     Result compute(ComputeMode computeMode = mComputeMode) {
593         switch (computeMode) {
594             case ComputeMode::SYNC: {
595                 return static_cast<Result>(mNnApi->ANeuralNetworksExecution_compute(mExecution));
596             }
597             case ComputeMode::BURST: {
598                 ANeuralNetworksBurst* burst = nullptr;
599                 Result result = static_cast<Result>(
600                         mNnApi->ANeuralNetworksBurst_create(mCompilation, &burst));
601                 if (result != Result::NO_ERROR) {
602                     return result;
603                 }
604                 result = static_cast<Result>(
605                         mNnApi->ANeuralNetworksExecution_burstCompute(mExecution, burst));
606                 mNnApi->ANeuralNetworksBurst_free(burst);
607                 return result;
608             }
609             case ComputeMode::FENCED: {
610                 ANeuralNetworksEvent* event = nullptr;
611                 Result result = static_cast<Result>(
612                         mNnApi->ANeuralNetworksExecution_startComputeWithDependencies(
613                                 mExecution, nullptr, 0, 0, &event));
614                 if (result != Result::NO_ERROR) {
615                     return result;
616                 }
617                 result = static_cast<Result>(mNnApi->ANeuralNetworksEvent_wait(event));
618                 mNnApi->ANeuralNetworksEvent_free(event);
619                 return result;
620             }
621         }
622         return Result::BAD_DATA;
623     }
624 
startComputeWithDependencies(const std::vector<const ANeuralNetworksEvent * > & deps,uint64_t duration,Event * event)625     Result startComputeWithDependencies(const std::vector<const ANeuralNetworksEvent*>& deps,
626                                         uint64_t duration, Event* event) {
627         ANeuralNetworksEvent* ev = nullptr;
628         Result result = static_cast<Result>(
629                 NNAPI_CALL(ANeuralNetworksExecution_startComputeWithDependencies(
630                         mExecution, deps.data(), deps.size(), duration, &ev)));
631         event->set(ev);
632         return result;
633     }
634 
getOutputOperandDimensions(uint32_t index,std::vector<uint32_t> * dimensions)635     Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) {
636         uint32_t rank = 0;
637         Result result = static_cast<Result>(
638                 mNnApi->ANeuralNetworksExecution_getOutputOperandRank(mExecution, index, &rank));
639         dimensions->resize(rank);
640         if ((result != Result::NO_ERROR && result != Result::OUTPUT_INSUFFICIENT_SIZE) ||
641             rank == 0) {
642             return result;
643         }
644         result = static_cast<Result>(mNnApi->ANeuralNetworksExecution_getOutputOperandDimensions(
645                 mExecution, index, dimensions->data()));
646         return result;
647     }
648 
getHandle()649     ANeuralNetworksExecution* getHandle() { return mExecution; };
650 
651    private:
652     const NnApiSupportLibrary* mNnApi = nullptr;
653     ANeuralNetworksCompilation* mCompilation = nullptr;
654     ANeuralNetworksExecution* mExecution = nullptr;
655 
656     // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp.
657     static ComputeMode mComputeMode;
658 };
659 
660 }  // namespace sl_wrapper
661 }  // namespace nn
662 }  // namespace android
663 
664 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_SL_SUPPORT_LIBRARY_WRAPPER_H
665