1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Provides C++ classes to more easily use the Neural Networks API. 18 // TODO(b/117845862): this should be auto generated from NeuralNetworksWrapper.h. 19 20 #ifndef ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H 21 #define ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H 22 23 #include "NeuralNetworks.h" 24 #include "NeuralNetworksWrapper.h" 25 #include "NeuralNetworksWrapperExtensions.h" 26 27 #include <math.h> 28 #include <optional> 29 #include <string> 30 #include <vector> 31 32 namespace android { 33 namespace nn { 34 namespace test_wrapper { 35 36 using wrapper::Event; 37 using wrapper::ExecutePreference; 38 using wrapper::ExtensionModel; 39 using wrapper::ExtensionOperandParams; 40 using wrapper::ExtensionOperandType; 41 using wrapper::Memory; 42 using wrapper::Model; 43 using wrapper::OperandType; 44 using wrapper::Result; 45 using wrapper::SymmPerChannelQuantParams; 46 using wrapper::Type; 47 48 class Compilation { 49 public: Compilation(const Model * model)50 Compilation(const Model* model) { 51 int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation); 52 if (result != 0) { 53 // TODO Handle the error 54 } 55 } 56 Compilation()57 Compilation() {} 58 ~Compilation()59 ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); } 60 61 // Disallow copy semantics to ensure the runtime object can only be freed 62 // once. Copy semantics could be enabled if some sort of reference counting 63 // or deep-copy system for runtime objects is added later. 64 Compilation(const Compilation&) = delete; 65 Compilation& operator=(const Compilation&) = delete; 66 67 // Move semantics to remove access to the runtime object from the wrapper 68 // object that is being moved. This ensures the runtime object will be 69 // freed only once. Compilation(Compilation && other)70 Compilation(Compilation&& other) { *this = std::move(other); } 71 Compilation& operator=(Compilation&& other) { 72 if (this != &other) { 73 ANeuralNetworksCompilation_free(mCompilation); 74 mCompilation = other.mCompilation; 75 other.mCompilation = nullptr; 76 } 77 return *this; 78 } 79 setPreference(ExecutePreference preference)80 Result setPreference(ExecutePreference preference) { 81 return static_cast<Result>(ANeuralNetworksCompilation_setPreference( 82 mCompilation, static_cast<int32_t>(preference))); 83 } 84 setCaching(const std::string & cacheDir,const std::vector<uint8_t> & token)85 Result setCaching(const std::string& cacheDir, const std::vector<uint8_t>& token) { 86 if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN) { 87 return Result::BAD_DATA; 88 } 89 return static_cast<Result>(ANeuralNetworksCompilation_setCaching( 90 mCompilation, cacheDir.c_str(), token.data())); 91 } 92 finish()93 Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); } 94 getHandle()95 ANeuralNetworksCompilation* getHandle() const { return mCompilation; } 96 97 protected: 98 ANeuralNetworksCompilation* mCompilation = nullptr; 99 }; 100 101 class Execution { 102 public: Execution(const Compilation * compilation)103 Execution(const Compilation* compilation) : mCompilation(compilation->getHandle()) { 104 int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution); 105 if (result != 0) { 106 // TODO Handle the error 107 } 108 } 109 ~Execution()110 ~Execution() { ANeuralNetworksExecution_free(mExecution); } 111 112 // Disallow copy semantics to ensure the runtime object can only be freed 113 // once. Copy semantics could be enabled if some sort of reference counting 114 // or deep-copy system for runtime objects is added later. 115 Execution(const Execution&) = delete; 116 Execution& operator=(const Execution&) = delete; 117 118 // Move semantics to remove access to the runtime object from the wrapper 119 // object that is being moved. This ensures the runtime object will be 120 // freed only once. Execution(Execution && other)121 Execution(Execution&& other) { *this = std::move(other); } 122 Execution& operator=(Execution&& other) { 123 if (this != &other) { 124 ANeuralNetworksExecution_free(mExecution); 125 mCompilation = other.mCompilation; 126 other.mCompilation = nullptr; 127 mExecution = other.mExecution; 128 other.mExecution = nullptr; 129 } 130 return *this; 131 } 132 133 Result setInput(uint32_t index, const void* buffer, size_t length, 134 const ANeuralNetworksOperandType* type = nullptr) { 135 return static_cast<Result>( 136 ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length)); 137 } 138 139 Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 140 uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 141 return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory( 142 mExecution, index, type, memory->get(), offset, length)); 143 } 144 145 Result setOutput(uint32_t index, void* buffer, size_t length, 146 const ANeuralNetworksOperandType* type = nullptr) { 147 return static_cast<Result>( 148 ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length)); 149 } 150 151 Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, 152 uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { 153 return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory( 154 mExecution, index, type, memory->get(), offset, length)); 155 } 156 startCompute(Event * event)157 Result startCompute(Event* event) { 158 ANeuralNetworksEvent* ev = nullptr; 159 Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev)); 160 event->set(ev); 161 return result; 162 } 163 compute()164 Result compute() { 165 switch (mComputeMode) { 166 case ComputeMode::SYNC: { 167 return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution)); 168 } 169 case ComputeMode::ASYNC: { 170 ANeuralNetworksEvent* event = nullptr; 171 Result result = static_cast<Result>( 172 ANeuralNetworksExecution_startCompute(mExecution, &event)); 173 if (result != Result::NO_ERROR) { 174 return result; 175 } 176 // TODO how to manage the lifetime of events when multiple waiters is not 177 // clear. 178 result = static_cast<Result>(ANeuralNetworksEvent_wait(event)); 179 ANeuralNetworksEvent_free(event); 180 return result; 181 } 182 case ComputeMode::BURST: { 183 ANeuralNetworksBurst* burst = nullptr; 184 Result result = 185 static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst)); 186 if (result != Result::NO_ERROR) { 187 return result; 188 } 189 result = static_cast<Result>( 190 ANeuralNetworksExecution_burstCompute(mExecution, burst)); 191 ANeuralNetworksBurst_free(burst); 192 return result; 193 } 194 } 195 return Result::BAD_DATA; 196 } 197 198 // By default, compute() uses the synchronous API. setComputeMode() can be 199 // used to change the behavior of compute() to either: 200 // - use the asynchronous API and then wait for computation to complete 201 // or 202 // - use the burst API 203 // Returns the previous ComputeMode. 204 enum class ComputeMode { SYNC, ASYNC, BURST }; setComputeMode(ComputeMode mode)205 static ComputeMode setComputeMode(ComputeMode mode) { 206 ComputeMode oldComputeMode = mComputeMode; 207 mComputeMode = mode; 208 return oldComputeMode; 209 } 210 getOutputOperandDimensions(uint32_t index,std::vector<uint32_t> * dimensions)211 Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) { 212 uint32_t rank = 0; 213 Result result = static_cast<Result>( 214 ANeuralNetworksExecution_getOutputOperandRank(mExecution, index, &rank)); 215 dimensions->resize(rank); 216 if ((result != Result::NO_ERROR && result != Result::OUTPUT_INSUFFICIENT_SIZE) || 217 rank == 0) { 218 return result; 219 } 220 result = static_cast<Result>(ANeuralNetworksExecution_getOutputOperandDimensions( 221 mExecution, index, dimensions->data())); 222 return result; 223 } 224 225 private: 226 ANeuralNetworksCompilation* mCompilation = nullptr; 227 ANeuralNetworksExecution* mExecution = nullptr; 228 229 // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp. 230 static ComputeMode mComputeMode; 231 }; 232 233 } // namespace test_wrapper 234 } // namespace nn 235 } // namespace android 236 237 #endif // ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H 238