/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Provides C++ classes to more easily use the Neural Networks API.

#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H

#include "NeuralNetworks.h"

#include <math.h>

#include <utility>
#include <vector>

namespace android {
namespace nn {
namespace wrapper {

enum class Type {
    FLOAT32 = ANEURALNETWORKS_FLOAT32,
    INT32 = ANEURALNETWORKS_INT32,
    UINT32 = ANEURALNETWORKS_UINT32,
    TENSOR_FLOAT32 = ANEURALNETWORKS_TENSOR_FLOAT32,
    TENSOR_INT32 = ANEURALNETWORKS_TENSOR_INT32,
    TENSOR_QUANT8_ASYMM = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM,
};

enum class ExecutePreference {
    PREFER_LOW_POWER = ANEURALNETWORKS_PREFER_LOW_POWER,
    PREFER_FAST_SINGLE_ANSWER = ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER,
    PREFER_SUSTAINED_SPEED = ANEURALNETWORKS_PREFER_SUSTAINED_SPEED
};

enum class Result {
    NO_ERROR = ANEURALNETWORKS_NO_ERROR,
    OUT_OF_MEMORY = ANEURALNETWORKS_OUT_OF_MEMORY,
    INCOMPLETE = ANEURALNETWORKS_INCOMPLETE,
    UNEXPECTED_NULL = ANEURALNETWORKS_UNEXPECTED_NULL,
    BAD_DATA = ANEURALNETWORKS_BAD_DATA,
};

struct OperandType {
    ANeuralNetworksOperandType operandType;
    // int32_t type;
    std::vector<uint32_t> dimensions;

    OperandType(Type type, const std::vector<uint32_t>& d, float scale = 0.0f,
                int32_t zeroPoint = 0)
        : dimensions(d) {
        operandType.type = static_cast<int32_t>(type);
        operandType.scale = scale;
        operandType.zeroPoint = zeroPoint;

        operandType.dimensionCount = static_cast<uint32_t>(dimensions.size());
        operandType.dimensions = dimensions.data();
    }
};
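
// Example (illustrative only; the shapes, scale, and zeroPoint below are
// arbitrary values chosen to show the constructor parameters -- scale and
// zeroPoint only matter for quantized types such as TENSOR_QUANT8_ASYMM):
//
//     OperandType floatTensor(Type::TENSOR_FLOAT32, {2, 3});
//     OperandType quantTensor(Type::TENSOR_QUANT8_ASYMM, {2, 3}, 0.5f, 128);
//     OperandType activation(Type::INT32, {});  // scalar operand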

class Memory {
public:
    Memory(size_t size, int protect, int fd, size_t offset) {
        mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
                 ANEURALNETWORKS_NO_ERROR;
    }

    ~Memory() { ANeuralNetworksMemory_free(mMemory); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Memory(const Memory&) = delete;
    Memory& operator=(const Memory&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Memory(Memory&& other) { *this = std::move(other); }
    Memory& operator=(Memory&& other) {
        if (this != &other) {
            mMemory = other.mMemory;
            mValid = other.mValid;
            other.mMemory = nullptr;
            other.mValid = false;
        }
        return *this;
    }

    ANeuralNetworksMemory* get() const { return mMemory; }
    bool isValid() const { return mValid; }

private:
    ANeuralNetworksMemory* mMemory = nullptr;
    bool mValid = true;
};

class Model {
public:
    Model() {
        // TODO handle the value returned by this call
        ANeuralNetworksModel_create(&mModel);
    }
    ~Model() { ANeuralNetworksModel_free(mModel); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Model(const Model&) = delete;
    Model& operator=(const Model&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Model(Model&& other) { *this = std::move(other); }
    Model& operator=(Model&& other) {
        if (this != &other) {
            mModel = other.mModel;
            mNextOperandId = other.mNextOperandId;
            mValid = other.mValid;
            other.mModel = nullptr;
            other.mNextOperandId = 0;
            other.mValid = false;
        }
        return *this;
    }

    Result finish() { return static_cast<Result>(ANeuralNetworksModel_finish(mModel)); }

    uint32_t addOperand(const OperandType* type) {
        if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
        return mNextOperandId++;
    }

    void setOperandValue(uint32_t index, const void* buffer, size_t length) {
        if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
            ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                                   size_t length) {
        if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
                                                           length) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
                      const std::vector<uint32_t>& outputs) {
        if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
                                              inputs.data(), static_cast<uint32_t>(outputs.size()),
                                              outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs,
                                  const std::vector<uint32_t>& outputs) {
        if (ANeuralNetworksModel_identifyInputsAndOutputs(
                    mModel, static_cast<uint32_t>(inputs.size()), inputs.data(),
                    static_cast<uint32_t>(outputs.size()),
                    outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
            mValid = false;
        }
    }

    ANeuralNetworksModel* getHandle() const { return mModel; }
    bool isValid() const { return mValid; }

private:
    ANeuralNetworksModel* mModel = nullptr;
    // We keep track of the operand ID as a convenience to the caller.
    uint32_t mNextOperandId = 0;
    bool mValid = true;
};
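
// Example (illustrative sketch of building a small model with this wrapper;
// the operand shapes and values are arbitrary. ANEURALNETWORKS_ADD takes two
// input tensors plus a fused-activation scalar, here set to FUSED_NONE):
//
//     Model model;
//     OperandType tensorType(Type::TENSOR_FLOAT32, {2});
//     OperandType actType(Type::INT32, {});
//     uint32_t a = model.addOperand(&tensorType);
//     uint32_t b = model.addOperand(&tensorType);
//     uint32_t act = model.addOperand(&actType);
//     uint32_t sum = model.addOperand(&tensorType);
//     int32_t noActivation = ANEURALNETWORKS_FUSED_NONE;
//     model.setOperandValue(act, &noActivation, sizeof(noActivation));
//     model.addOperation(ANEURALNETWORKS_ADD, {a, b, act}, {sum});
//     model.identifyInputsAndOutputs({a, b}, {sum});
//     Result r = model.finish();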

class Event {
public:
    Event() {}
    ~Event() { ANeuralNetworksEvent_free(mEvent); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Event(const Event&) = delete;
    Event& operator=(const Event&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Event(Event&& other) { *this = std::move(other); }
    Event& operator=(Event&& other) {
        if (this != &other) {
            mEvent = other.mEvent;
            other.mEvent = nullptr;
        }
        return *this;
    }

    Result wait() { return static_cast<Result>(ANeuralNetworksEvent_wait(mEvent)); }

    // Only for use by Execution
    void set(ANeuralNetworksEvent* newEvent) {
        ANeuralNetworksEvent_free(mEvent);
        mEvent = newEvent;
    }

private:
    ANeuralNetworksEvent* mEvent = nullptr;
};

class Compilation {
public:
    Compilation(const Model* model) {
        int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); }

    Compilation(const Compilation&) = delete;
    Compilation& operator=(const Compilation&) = delete;

    Compilation(Compilation&& other) { *this = std::move(other); }
    Compilation& operator=(Compilation&& other) {
        if (this != &other) {
            mCompilation = other.mCompilation;
            other.mCompilation = nullptr;
        }
        return *this;
    }

    Result setPreference(ExecutePreference preference) {
        return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
                mCompilation, static_cast<int32_t>(preference)));
    }

    Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); }

    ANeuralNetworksCompilation* getHandle() const { return mCompilation; }

private:
    ANeuralNetworksCompilation* mCompilation = nullptr;
};
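
// Example (illustrative sketch of compiling and running a finished model;
// buffer names and sizes are arbitrary and assume the ADD sketch above):
//
//     Compilation compilation(&model);
//     compilation.setPreference(ExecutePreference::PREFER_FAST_SINGLE_ANSWER);
//     compilation.finish();
//
//     float in0[2] = {1.0f, 2.0f}, in1[2] = {3.0f, 4.0f}, out[2] = {};
//     Execution execution(&compilation);
//     execution.setInput(0, in0, sizeof(in0));
//     execution.setInput(1, in1, sizeof(in1));
//     execution.setOutput(0, out, sizeof(out));
//     Result r = execution.compute();  // or startCompute(&event) + event.wait()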

class Execution {
public:
    Execution(const Compilation* compilation) {
        int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution);
        if (result != 0) {
            // TODO Handle the error
        }
    }

    ~Execution() { ANeuralNetworksExecution_free(mExecution); }

    // Disallow copy semantics to ensure the runtime object can only be freed
    // once. Copy semantics could be enabled if some sort of reference counting
    // or deep-copy system for runtime objects is added later.
    Execution(const Execution&) = delete;
    Execution& operator=(const Execution&) = delete;

    // Move semantics to remove access to the runtime object from the wrapper
    // object that is being moved. This ensures the runtime object will be
    // freed only once.
    Execution(Execution&& other) { *this = std::move(other); }
    Execution& operator=(Execution&& other) {
        if (this != &other) {
            mExecution = other.mExecution;
            other.mExecution = nullptr;
        }
        return *this;
    }

    Result setInput(uint32_t index, const void* buffer, size_t length,
                    const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length));
    }

    Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                              uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory(
                mExecution, index, type, memory->get(), offset, length));
    }

    Result setOutput(uint32_t index, void* buffer, size_t length,
                     const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(
                ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length));
    }

    Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
                               uint32_t length, const ANeuralNetworksOperandType* type = nullptr) {
        return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory(
                mExecution, index, type, memory->get(), offset, length));
    }

    Result startCompute(Event* event) {
        ANeuralNetworksEvent* ev = nullptr;
        Result result =
                static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev));
        event->set(ev);
        return result;
    }

    Result compute() {
        ANeuralNetworksEvent* event = nullptr;
        Result result =
                static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
        if (result != Result::NO_ERROR) {
            return result;
        }
        // TODO: It is not clear how to manage the lifetime of the event when
        // there are multiple waiters.
        result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
        ANeuralNetworksEvent_free(event);
        return result;
    }

private:
    ANeuralNetworksExecution* mExecution = nullptr;
};

}  // namespace wrapper
}  // namespace nn
}  // namespace android

#endif  // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H