1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_EXECUTION_CALLBACK_H 18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_EXECUTION_CALLBACK_H 19 20 #include <android-base/thread_annotations.h> 21 #include <nnapi/Types.h> 22 23 #include <condition_variable> 24 #include <functional> 25 #include <mutex> 26 #include <thread> 27 #include <vector> 28 29 namespace android::nn { 30 31 // This class used to be a HIDL callback class to receive the results of 32 // IDevice::execute* asynchronously. It's not used for this anymore. 33 // 34 // TODO(b/122316159): Replace ExecutionCallback and CallbackEvent with a new 35 // class like AsyncTaskEvent. 36 /** 37 * The ExecutionCallback class is used to receive the results of the execution 38 * from a task executing asynchronously with respect to the runtime. If a 39 * calling thread calls wait or get* on a ExecutionCallback object and the 40 * corresponding asynchronous task has not finished the execution, the calling 41 * thread will block until the asynchronous task has called one of the notify* 42 * methods. 43 * 44 * If the callback object is notified more than once, only the results of the 45 * first call to notify* are used, and the results from subsequent calls are 46 * discarded. 47 */ 48 class ExecutionCallback { 49 using ExecutionFinish = 50 std::function<ErrorStatus(ErrorStatus, const std::vector<OutputShape>&)>; 51 52 public: 53 /** 54 * ExecutionCallback::notify marks the callback object with the results 55 * (error status, dynamic output shapes, and timing information) of the 56 * asynchronous execution that held this callback and enables all prior and 57 * future wait calls on the ExecutionCallback object to proceed. 58 * 59 * If the callback object is notified more than once, only the results of 60 * the first call to notify* are used, and the results from subsequent calls 61 * are discarded. 62 * 63 * @param status Error status returned from launching the asynchronous task 64 * (if the launch fails) or from the asynchronous task itself (if the 65 * launch succeeds). Must be: 66 * - NONE if the asynchronous execution was successful 67 * - DEVICE_UNAVAILABLE if driver is offline or busy 68 * - GENERAL_FAILURE if the asynchronous task resulted in an unspecified 69 * error 70 * - OUTPUT_INSUFFICIENT_SIZE if at least one output operand buffer is 71 * not large enough to store the corresponding output 72 * - INVALID_ARGUMENT if one of the input arguments to prepareModel is 73 * invalid 74 * - MISSED_DEADLINE_* if the deadline could not be met 75 * - RESOURCE_EXHAUSTED_* if the execution was aborted by the driver 76 * @param outputShapes A list of shape information of model output operands. 77 * The index into "outputShapes" corresponds to the index of the output 78 * operand in the Request outputs vector. outputShapes must be empty 79 * unless the status is either NONE or OUTPUT_INSUFFICIENT_SIZE. 80 * @param Timing Duration of execution. Unless MeasureTiming::YES was passed 81 * when launching the execution and status is NONE, all times must be 82 * reported as UINT64_MAX. A driver may choose to report any time as 83 * UINT64_MAX, indicating that particular measurement is not available. 84 */ 85 void notify(ErrorStatus status, const std::vector<OutputShape>& outputShapes, 86 const Timing& timing); 87 88 /** 89 * ExecutionCallback::wait blocks until notify* has been called on the 90 * callback object. 91 */ 92 void wait() const; 93 94 /** 95 * Retrieves the error status returned from the asynchronous task launched 96 * by IPreparedModel::execute* (but not by 97 * IPreparedModel::executeSynchronously*). If IPreparedModel::execute* has 98 * not finished asynchronously executing, this call will block until the 99 * asynchronous task notifies the object. 100 * 101 * @return status Error status returned from launching the asynchronous task 102 * (if the launch fails) or from the asynchronous task itself (if the 103 * launch succeeds). Must be: 104 * - NONE if the asynchronous execution was successful 105 * - DEVICE_UNAVAILABLE if driver is offline or busy 106 * - GENERAL_FAILURE if the asynchronous task resulted in an unspecified 107 * error 108 * - OUTPUT_INSUFFICIENT_SIZE if at least one output operand buffer is 109 * not large enough to store the corresponding output 110 * - INVALID_ARGUMENT if one of the input arguments to prepareModel is 111 * invalid 112 * - MISSED_DEADLINE_* if the deadline could not be met 113 * - RESOURCE_EXHAUSTED_* if the task was aborted by the driver 114 * - DEAD_OBJECT if the driver crashed without returning a result 115 */ 116 ErrorStatus getStatus() const; 117 118 /** 119 * Retrieves the output shapes returned from the asynchronous task launched 120 * by either IPreparedModel::execute_1_2 or IPreparedModel::execute_1_3. If 121 * IPreparedModel::execute_1_2 or IPreparedModel::execute_1_3 has not 122 * finished asynchronously executing, this call will block until the 123 * asynchronous task notifies the object. 124 * 125 * If the asynchronous task was launched by IPreparedModel::execute, an 126 * empty vector will be returned. 127 * 128 * @return outputShapes A list of shape information of model output 129 * operands. The index into "outputShapes" corresponds to the index of 130 * the output operand in the Request outputs vector. outputShapes must 131 * be empty unless the status is either NONE or 132 * OUTPUT_INSUFFICIENT_SIZE. outputShaps may be empty if the status is 133 * NONE and all model output operands are fully-specified at execution 134 * time. outputShapes must have the same number of elements as the 135 * number of model output operands if the status is 136 * OUTPUT_INSUFFICIENT_SIZE, or if the status is NONE and the model has 137 * at least one output operand that is not fully-specified. 138 */ 139 const std::vector<OutputShape>& getOutputShapes() const; 140 141 /** 142 * Retrieves the duration of execution of the asynchronous task launched by 143 * by either IPreparedModel::execute_1_2 or IPreparedModel::execute_1_3. If 144 * IPreparedModel::execute_1_2 or IPreparedModel::execute_1_3 has not 145 * finished asynchronously executing, this call will block until the 146 * asynchronous task notifies the object. 147 * 148 * If the asynchronous task was launched by IPreparedModel::execute, every 149 * time must be UINT64_MAX. 150 * 151 * @return timing Duration of the execution. Every time must be UINT64_MAX 152 * unless the status is NONE. 153 */ 154 Timing getTiming() const; 155 156 /** 157 * ExecutionCallback::bindThread binds a thread to the ExecutionCallback 158 * object. The bound thread is later joined by ExecutionCallback::wait or 159 * ExecutionCallback::get*. 160 * 161 * Once a thread is bound with ExecutionCallback::bindThread, the client 162 * code must ensure that ExecutionCallback::wait or ExecutionCallback::get* 163 * has been called before the ExecutionCallback object is destroyed. 164 * 165 * The bound thread must not call any ExecutionCallback method with the 166 * exception of ExecutionCallback::notify*, which it must call when the 167 * thread has finished its computation. 168 * 169 * ExecutionCallback::bindThread can be called at most once on a given 170 * callback object. 171 * 172 * @param asyncThread Thread to be bound to the callback object. The thread 173 * object must represent a thread of execution -- i.e., 174 * std::thread::joinable() must be true. 175 * @return bool True if successful, false if thread was not properly bound. 176 */ 177 bool bindThread(std::thread asyncThread); 178 179 /** 180 * ExecutionCallback::setOnFinish binds a callback to the ExecutionCallback 181 * object that will be executed during one of the ExecutionCallback::notify* 182 * calls but before any calls to wait or get* return. This provided callback 183 * is provided with both the ErrorStatus and the output shapes from 184 * ExecutionCallback::notify*. 185 * 186 * The bound function must not synchronize with or otherwise access the 187 * callback object it is bound to, as this could cause a deadlock. 188 * 189 * This call will not bind the provided callback if any of the following 190 * occur: 191 * (1) the provided callback is invalid (i.e., "(bool) finish" is false) 192 * (2) ExecutionCallback already contains a bound callback 193 * (3) ExecutionCallback has already been notified with results 194 * 195 * @param finish Callback to be executed when ExecutionCallback is notified 196 * with results. 197 */ 198 void setOnFinish(const ExecutionFinish& finish); 199 200 private: 201 /* 202 * ExecutionCallback::notifyInternal stores the results of the execution 203 * (status, output shapes, and timing information) in the ExecutionCallback 204 * object and invokes the bound callback function "mOnFinish" (if present) 205 * before any call to wait or get* return. It then enables all prior and 206 * future wait calls on the ExecutionCallback object to proceed. 207 */ 208 void notifyInternal(ErrorStatus errorStatus, std::vector<OutputShape> outputShapes, 209 Timing timing); 210 211 // members 212 mutable std::mutex mMutex; 213 mutable std::condition_variable mCondition; 214 mutable std::thread mThread GUARDED_BY(mMutex); 215 ExecutionFinish mOnFinish GUARDED_BY(mMutex); 216 bool mNotified GUARDED_BY(mMutex) = false; 217 ErrorStatus mErrorStatus = ErrorStatus::GENERAL_FAILURE; 218 std::vector<OutputShape> mOutputShapes; 219 Timing mTiming = {}; 220 }; 221 222 } // namespace android::nn 223 224 #endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_EXECUTION_CALLBACK_H 225