1 #ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
2 #define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
3
4 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
5 #include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
6 #include <chrono>
7 #include <condition_variable>
8 #include <functional>
9 #include <hidl/MQDescriptor.h>
10 #include <hidl/Status.h>
11 #include <mutex>
12 #include <thread>
13
14 namespace android {
15 namespace hardware {
16 namespace neuralnetworks {
17 namespace V1_0 {
18 namespace implementation {
19
20 using ::android::hardware::hidl_array;
21 using ::android::hardware::hidl_memory;
22 using ::android::hardware::hidl_string;
23 using ::android::hardware::hidl_vec;
24 using ::android::hardware::Return;
25 using ::android::hardware::Void;
26 using ::android::sp;
27
/**
 * The CallbackBase class is used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed or a timeout condition has been
 * reached. Any wait* may safely be called concurrently, even on the same
 * callback object. When the asynchronous task has finished its workload, it
 * must immediately call "notify". If the asynchronous task has failed to launch,
 * the function that tried to launch the asynchronous task must immediately call
 * "notify". This "notify" call awakens any client threads waiting on the
 * callback object.
 *
 * The CallbackBase class implements some of the base synchronization common to
 * both PreparedModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as the HIDL callback
 * interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When synchronization
 * is only required in the same process, consider using std::future, std::mutex,
 * std::condition_variable, or std::experimental::latch instead.
 */
class CallbackBase {
public:
    CallbackBase();

    // Virtual because CallbackBase is designed to be inherited from (every
    // HIDL callback class must derive from it); this keeps destruction through
    // a CallbackBase pointer well-defined.
    virtual ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or the time duration from the time the wait_for function
     * was called has expired, whichever comes first.
     *
     * @return Status std::cv_status::no_timeout if the callback was notified
     *                before the time duration expired, std::cv_status::timeout
     *                otherwise.
     */
    template<class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep,Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. This
     * bound function will be executed when CallbackBase::notify is called,
     * before any calls to wait* return. (Note that CallbackBase::wait_for can
     * return std::cv_status::timeout before CallbackBase::notify is called for
     * the first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it successfully completes, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false if
     *              unsuccessful.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the event for later use by
     * CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the event is
     * destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned
     *   std::cv_status::no_timeout (i.e., it did not time out; only then does
     *   wait_for join the bound thread).
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread has
     * finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
     */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * event with CallbackBase::bind_thread has fully finished and cleaned its
     * resources. It is legal to call this function multiple times, concurrently
     * or sequentially.
     */
    void join_thread();

protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* calls on this callback object return (except in the case
     * of wait_for timing out). The asynchronous call the callback object is
     * paired with must ensure that any update to state that should be visible
     * to the caller of wait* happens before the call to CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

private:
    // Same as CallbackBase::join_thread but assumes we already hold a lock on
    // mMutex.
    void join_thread_locked();

    // Set once notify() has been called; read under mMutex by the wait*
    // predicates. Defaults to "not yet notified".
    bool mNotified = false;
    std::mutex mMutex;
    std::condition_variable mCondition;
    std::function<bool(void)> mPostWork;
    std::thread mThread;
};
164
165 /**
166 * The PreparedModelCallback class is used to receive the error status of
167 * preparing a model as well as the prepared model from a task executing
168 * asynchronously with respect to the runtime. If a calling thread calls wait*
169 * or get* on a PreparedModelCallback object and the corresponding asynchronous
170 * task has not finished preparing the model, the calling thread will block
171 * until the asynchronous task has called notify. For more information on the
172 * synchronization behavior, refer to the CallbackBase class.
173 *
174 * This class inherits the basic blocking and signaling calls from
175 * CallbackBase, and implements the HIDL notify call from
176 * IPreparedModelCallback. This callback object is passed as an argument to
177 * IDevice::prepareModel.
178 */
179 class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
180 public:
181 PreparedModelCallback();
182 ~PreparedModelCallback() override;
183
184 /**
185 * IPreparedModelCallback::notify marks the callback object with the return
186 * status of the asynchronous model preparation along with the prepared
187 * model, and calls CallbackBase::notify, enabling all prior and future
188 * wait* calls on the PreparedModelCallback object to proceed. For more
189 * information on the synchronization behavior, refer to the CallbackBase
190 * class.
191 *
192 * IPreparedModelCallback::notify must be called exactly once on a given
193 * PreparedModelCallback object.
194 *
195 * @param status Error status returned from asynchronously preparing the
196 * model; will be:
197 * - NONE if the asynchronous preparation was successful
198 * - DEVICE_UNAVAILABLE if driver is offline or busy
199 * - GENERAL_FAILURE if there is an unspecified error
200 * - INVALID_ARGUMENT if the input model is invalid
201 * @param preparedModel Returned model that has been prepared for execution,
202 * nullptr if the model was unable to be prepared.
203 */
204 Return<void> notify(ErrorStatus status, const sp<IPreparedModel>& preparedModel) override;
205
206 /**
207 * Retrieves the error status returned from the asynchronous task launched
208 * by IDevice::prepareModel. If IDevice::prepareModel has not finished
209 * asynchronously preparing the model, this call will block until the
210 * asynchronous task notifies the object.
211 *
212 * @return status Error status returned from asynchronously preparing the
213 * model; will be:
214 * - NONE if the asynchronous preparation was successful
215 * - DEVICE_UNAVAILABLE if driver is offline or busy
216 * - GENERAL_FAILURE if there is an unspecified error
217 * - INVALID_ARGUMENT if the input model is invalid
218 */
219 ErrorStatus getStatus();
220
221 /**
222 * Retrieves the model that has been prepared for execution from the
223 * asynchronous task launched by IDevice::prepareModel. If
224 * IDevice::prepareModel has not finished asynchronously preparing the
225 * model, this call will block until the asynchronous task notifies the
226 * object.
227 *
228 * @return preparedModel Returned model that has been prepared for
229 * execution, nullptr if the model was unable to be
230 * prepared.
231 */
232 sp<IPreparedModel> getPreparedModel();
233
234 private:
235 ErrorStatus mErrorStatus;
236 sp<IPreparedModel> mPreparedModel;
237 };
238
239 /**
240 * The ExecutionCallback class is used to receive the error status of the
241 * execution from a task executing asynchronously with respect to the runtime.
242 * If a calling thread calls wait* or get* on a PreparedModelCallback object and
243 * the corresponding asynchronous task has not finished the execution, the
244 * calling thread will block until the asynchronous task has called notify. For
245 * more information on the synchronization behavior, refer to the CallbackBase
246 * class.
247 *
248 * This class inherits the basic blocking and signaling calls from
249 * CallbackBase, and implements the HIDL notify call from
250 * IExecutionCallback. This callback object is passed as an argument to
251 * IPreparedModel::execute.
252 */
253 class ExecutionCallback : public CallbackBase, public IExecutionCallback {
254 public:
255 ExecutionCallback();
256 ~ExecutionCallback() override;
257
258 /**
259 * IExecutionCallback::notify marks the callback object with the return
260 * status of the asynchronous execution that held this callback and enables
261 * all prior and future wait* calls on the ExecutionCallback object to
262 * proceed. For more information on the synchronization behavior, refer to
263 * the CallbackBase class.
264 *
265 * IExecutionCallback::notify must be called exactly once on a given
266 * ExecutionCallback object.
267 *
268 * @param status Error status returned from asynchronously preparing the
269 * model; will be:
270 * - NONE if the asynchronous execution was successful
271 * - DEVICE_UNAVAILABLE if driver is offline or busy
272 * - GENERAL_FAILURE if there is an unspecified error
273 * - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
274 * not large enough to store the resultant values
275 * - INVALID_ARGUMENT if the input request is invalid
276 */
277 Return<void> notify(ErrorStatus status) override;
278
279 /**
280 * Retrieves the error status returned from the asynchronous task launched
281 * by IPreparedModel::execute. If IPreparedModel::execute has not finished
282 * asynchronously executing, this call will block until the asynchronous task
283 * notifies the object.
284 *
285 * @return status Error status returned from asynchronously preparing the
286 * model; will be:
287 * - NONE if the asynchronous execution was successful
288 * - DEVICE_UNAVAILABLE if driver is offline or busy
289 * - GENERAL_FAILURE if there is an unspecified error
290 * - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
291 * not large enough to store the resultant values
292 * - INVALID_ARGUMENT if the input request is invalid
293 */
294 ErrorStatus getStatus();
295
296 private:
297 ErrorStatus mErrorStatus;
298 };
299
300
301 // template function implementation(s) below this point
302
303 template<class Rep, class Period>
wait_for(const std::chrono::duration<Rep,Period> & timeout_duration)304 std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep,Period>& timeout_duration) {
305 std::unique_lock<std::mutex> lock(mMutex);
306 std::cv_status status = mCondition.wait_for(lock, timeout_duration, [this]{return mNotified;});
307 if (status != std::cv_status::timeout) {
308 join_thread_locked();
309 }
310 return status;
311 }
312
313 } // namespace implementation
314 } // namespace V1_0
315 } // namespace neuralnetworks
316 } // namespace hardware
317 } // namespace android
318
319 #endif // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
320