• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Provides C++ classes to more easily use the Neural Networks API.
18 
19 #ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
20 #define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
21 
#include "NeuralNetworks.h"

#include <math.h>

#include <utility>
#include <vector>
26 
27 namespace android {
28 namespace nn {
29 namespace wrapper {
30 
31 enum class Type {
32     FLOAT32 = ANEURALNETWORKS_FLOAT32,
33     INT32 = ANEURALNETWORKS_INT32,
34     UINT32 = ANEURALNETWORKS_UINT32,
35     TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
36     TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
37     TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
38 };
39 
40 enum class ExecutePreference {
41     PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
42     PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
43     PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
44 };
45 
46 enum class Result {
47     NO_ERROR = ANEURALNETWORKS_NO_ERROR,
48     OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
49     INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
50     UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
51     BAD_DATA = ANEURALNETWORKS_BAD_DATA,
52 };
53 
54 struct OperandType {
55     ANeuralNetworksOperandType operandType;
56     // int32_t type;
57     std::vector<uint32_t> dimensions;
58 
59     OperandType(Type type, const std::vector<uint32_t>& d, float scale = 0.0f,
60                 int32_t zeroPoint = 0)
dimensionsOperandType61         : dimensions(d) {
62         operandType.type = static_cast<int32_t>(type);
63         operandType.scale = scale;
64         operandType.zeroPoint = zeroPoint;
65 
66         operandType.dimensionCount = static_cast<uint32_t>(dimensions.size());
67         operandType.dimensions = dimensions.data();
68     }
69 };
70 
71 class Memory {
72 public:
Memory(size_t size,int protect,int fd,size_t offset)73     Memory(size_t size, int protect, int fd, size_t offset) {
74         mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
75                  ANEURALNETWORKS_NO_ERROR;
76     }
77 
~Memory()78     ~Memory() { ANeuralNetworksMemory_free(mMemory); }
79 
80     // Disallow copy semantics to ensure the runtime object can only be freed
81     // once. Copy semantics could be enabled if some sort of reference counting
82     // or deep-copy system for runtime objects is added later.
83     Memory(const Memory&) = delete;
84     Memory& operator=(const Memory&) = delete;
85 
86     // Move semantics to remove access to the runtime object from the wrapper
87     // object that is being moved. This ensures the runtime object will be
88     // freed only once.
Memory(Memory && other)89     Memory(Memory&& other) { *this = std::move(other); }
90     Memory& operator=(Memory&& other) {
91         if (this != &other) {
92             mMemory = other.mMemory;
93             mValid = other.mValid;
94             other.mMemory = nullptr;
95             other.mValid = false;
96         }
97         return *this;
98     }
99 
get()100     ANeuralNetworksMemory* get() const { return mMemory; }
isValid()101     bool isValid() const { return mValid; }
102 
103 private:
104     ANeuralNetworksMemory* mMemory = nullptr;
105     bool mValid = true;
106 };
107 
108 class Model {
109 public:
Model()110     Model() {
111         // TODO handle the value returned by this call
112         ANeuralNetworksModel_create(&mModel);
113     }
~Model()114     ~Model() { ANeuralNetworksModel_free(mModel); }
115 
116     // Disallow copy semantics to ensure the runtime object can only be freed
117     // once. Copy semantics could be enabled if some sort of reference counting
118     // or deep-copy system for runtime objects is added later.
119     Model(const Model&) = delete;
120     Model& operator=(const Model&) = delete;
121 
122     // Move semantics to remove access to the runtime object from the wrapper
123     // object that is being moved. This ensures the runtime object will be
124     // freed only once.
Model(Model && other)125     Model(Model&& other) { *this = std::move(other); }
126     Model& operator=(Model&& other) {
127         if (this != &other) {
128             mModel = other.mModel;
129             mNextOperandId = other.mNextOperandId;
130             mValid = other.mValid;
131             other.mModel = nullptr;
132             other.mNextOperandId = 0;
133             other.mValid = false;
134         }
135         return *this;
136     }
137 
finish()138     Result finish() { return static_cast<Result>(ANeuralNetworksModel_finish(mModel)); }
139 
addOperand(const OperandType * type)140     uint32_t addOperand(const OperandType* type) {
141         if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
142             ANEURALNETWORKS_NO_ERROR) {
143             mValid = false;
144         }
145         return mNextOperandId++;
146     }
147 
setOperandValue(uint32_t index,const void * buffer,size_t length)148     void setOperandValue(uint32_t index, const void* buffer, size_t length) {
149         if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
150             ANEURALNETWORKS_NO_ERROR) {
151             mValid = false;
152         }
153     }
154 
setOperandValueFromMemory(uint32_t index,const Memory * memory,uint32_t offset,size_t length)155     void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
156                                    size_t length) {
157         if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
158                                                            length) != ANEURALNETWORKS_NO_ERROR) {
159             mValid = false;
160         }
161     }
162 
addOperation(ANeuralNetworksOperationType type,const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs)163     void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
164                       const std::vector<uint32_t>& outputs) {
165         if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
166                                               inputs.data(), static_cast<uint32_t>(outputs.size()),
167                                               outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
168             mValid = false;
169         }
170     }
identifyInputsAndOutputs(const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs)171     void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs,
172                                   const std::vector<uint32_t>& outputs) {
173         if (ANeuralNetworksModel_identifyInputsAndOutputs(
174                         mModel, static_cast<uint32_t>(inputs.size()), inputs.data(),
175                         static_cast<uint32_t>(outputs.size()),
176                         outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
177             mValid = false;
178         }
179     }
getHandle()180     ANeuralNetworksModel* getHandle() const { return mModel; }
isValid()181     bool isValid() const { return mValid; }
182 
183 private:
184     ANeuralNetworksModel* mModel = nullptr;
185     // We keep track of the operand ID as a convenience to the caller.
186     uint32_t mNextOperandId = 0;
187     bool mValid = true;
188 };
189 
190 class Event {
191 public:
Event()192     Event() {}
~Event()193     ~Event() { ANeuralNetworksEvent_free(mEvent); }
194 
195     // Disallow copy semantics to ensure the runtime object can only be freed
196     // once. Copy semantics could be enabled if some sort of reference counting
197     // or deep-copy system for runtime objects is added later.
198     Event(const Event&) = delete;
199     Event& operator=(const Event&) = delete;
200 
201     // Move semantics to remove access to the runtime object from the wrapper
202     // object that is being moved. This ensures the runtime object will be
203     // freed only once.
Event(Event && other)204     Event(Event&& other) { *this = std::move(other); }
205     Event& operator=(Event&& other) {
206         if (this != &other) {
207             mEvent = other.mEvent;
208             other.mEvent = nullptr;
209         }
210         return *this;
211     }
212 
wait()213     Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }
214 
215     // Only for use by Execution
set(ANeuralNetworksEvent * newEvent)216     void set(ANeuralNetworksEvent* newEvent) {
217         ANeuralNetworksEvent_free(mEvent);
218         mEvent = newEvent;
219     }
220 
221 private:
222     ANeuralNetworksEvent* mEvent = nullptr;
223 };
224 
225 class Compilation {
226 public:
Compilation(const Model * model)227     Compilation(const Model* model) {
228         int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
229         if (result != 0) {
230             // TODO Handle the error
231         }
232     }
233 
~Compilation()234     ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }
235 
236     Compilation(const Compilation&) = delete;
237     Compilation& operator=(const Compilation&) = delete;
238 
Compilation(Compilation && other)239     Compilation(Compilation&& other) { *this = std::move(other); }
240     Compilation& operator=(Compilation&& other) {
241         if (this != &other) {
242             mCompilation = other.mCompilation;
243             other.mCompilation = nullptr;
244         }
245         return *this;
246     }
247 
setPreference(ExecutePreference preference)248     Result setPreference(ExecutePreference preference) {
249         return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
250                     mCompilation, static_cast<int32_t>(preference)));
251     }
252 
finish()253     Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); }
254 
getHandle()255     ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
256 
257 private:
258     ANeuralNetworksCompilation* mCompilation = nullptr;
259 };
260 
261 class Execution {
262 public:
Execution(const Compilation * compilation)263     Execution(const Compilation* compilation) {
264         int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
265         if (result != 0) {
266             // TODO Handle the error
267         }
268     }
269 
~Execution()270     ~Execution() { ANeuralNetworksExecution_free(mExecution); }
271 
272     // Disallow copy semantics to ensure the runtime object can only be freed
273     // once. Copy semantics could be enabled if some sort of reference counting
274     // or deep-copy system for runtime objects is added later.
275     Execution(const Execution&) = delete;
276     Execution& operator=(const Execution&) = delete;
277 
278     // Move semantics to remove access to the runtime object from the wrapper
279     // object that is being moved. This ensures the runtime object will be
280     // freed only once.
Execution(Execution && other)281     Execution(Execution&& other) { *this = std::move(other); }
282     Execution& operator=(Execution&& other) {
283         if (this != &other) {
284             mExecution = other.mExecution;
285             other.mExecution = nullptr;
286         }
287         return *this;
288     }
289 
290     Result setInput(uint32_t index, const void* buffer, size_t length,
291                     const ANeuralNetworksOperandType* type = nullptr) {
292         return static_cast<Result>(
293                     ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
294     }
295 
296     Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
297                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
298         return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
299                     mExecution, index, type, memory->get(), offset, length));
300     }
301 
302     Result setOutput(uint32_t index, void* buffer, size_t length,
303                      const ANeuralNetworksOperandType* type = nullptr) {
304         return static_cast<Result>(
305                     ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
306     }
307 
308     Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
309                                uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
310         return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
311                     mExecution, index, type, memory->get(), offset, length));
312     }
313 
startCompute(Event * event)314     Result startCompute(Event* event) {
315         ANeuralNetworksEvent* ev = nullptr;
316         Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
317         event->set(ev);
318         return result;
319     }
320 
compute()321     Result compute() {
322         ANeuralNetworksEvent* event = nullptr;
323         Result result =
324                     static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
325         if (result != Result::NO_ERROR) {
326             return result;
327         }
328         // TODO how to manage the lifetime of events when multiple waiters is not
329         // clear.
330         result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
331         ANeuralNetworksEvent_free(event);
332         return result;
333     }
334 
335 private:
336     ANeuralNetworksExecution* mExecution = nullptr;
337 };
338 
339 }  // namespace wrapper
340 }  // namespace nn
341 }  // namespace android
342 
343 #endif  //  ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
344