• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
2 #define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
3 
4 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
5 #include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
6 #include <chrono>
7 #include <condition_variable>
8 #include <functional>
9 #include <hidl/MQDescriptor.h>
10 #include <hidl/Status.h>
11 #include <mutex>
12 #include <thread>
13 
14 namespace android {
15 namespace hardware {
16 namespace neuralnetworks {
17 namespace V1_0 {
18 namespace implementation {
19 
20 using ::android::hardware::hidl_array;
21 using ::android::hardware::hidl_memory;
22 using ::android::hardware::hidl_string;
23 using ::android::hardware::hidl_vec;
24 using ::android::hardware::Return;
25 using ::android::hardware::Void;
26 using ::android::sp;
27 
/**
 * The CallbackBase class is used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed or a timeout condition has been
 * reached. Any wait* may safely be called concurrently, even on the same
 * callback object. When the asynchronous task has finished its workload, it
 * must immediately call "notify". If the asynchronous task has failed to launch,
 * the function that tried to launch the asynchronous task must immediately call
 * "notify". This "notify" call awakens any client threads waiting on the
 * callback object.
 *
 * The CallbackBase class implements some of the base synchronization common to
 * both PreparedModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as the HIDL callback
 * interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When synchronization
 * is only required in the same process, consider using std::future, std::mutex,
 * std::condition_variable, or std::experimental::latch instead.
 */
class CallbackBase {
 public:
    CallbackBase();

    // CallbackBase is designed to be used as a base class (HIDL callback
    // classes are required to inherit from it), so the destructor is virtual:
    // destroying a derived callback through a CallbackBase pointer would
    // otherwise be undefined behavior.
    virtual ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or the time duration from the time the wait_for function
     * was called has expired, whichever comes first. If the callback was
     * notified in time, the thread bound with CallbackBase::bind_thread (if
     * any) is joined before wait_for returns.
     *
     * @param timeout_duration Maximum amount of time to block.
     * @return Status std::cv_status::no_timeout if the callback was notified
     *                before the time duration expired, std::cv_status::timeout
     *                otherwise.
     */
    template<class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep,Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. This
     * bound function will be executed when CallbackBase::notify is called,
     * before any calls to wait* return. (Note that CallbackBase::wait_for can
     * return std::cv_status::timeout before CallbackBase::notify is called for
     * the first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it successfully completes, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false if
     *              unsuccessful.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the event for later use by
     * CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the event is
     * destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned
     *   std::cv_status::no_timeout. (wait_for only joins the bound thread when
     *   the callback was notified before the timeout expired; a
     *   std::cv_status::timeout result leaves the thread unjoined.)
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread has
     * finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
     */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * event with CallbackBase::bind_thread has fully finished and cleaned its
     * resources. It is legal to call this function multiple times, concurrently
     * or sequentially.
     */
    void join_thread();

 protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* calls on this callback object return (except in the case
     * of wait_for timing out). The asynchronous call the callback object is
     * paired with must ensure that any update to state that should be visible
     * to the caller of wait* happens before the call to CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

 private:
    // Same as CallbackBase::join_thread but assumes we already hold a lock on
    // mMutex.
    void join_thread_locked();

    // Default member initializer guarantees the "not yet notified" state even
    // if a constructor omits mNotified from its initializer list.
    bool                      mNotified = false;
    std::mutex                mMutex;
    std::condition_variable   mCondition;
    std::function<bool(void)> mPostWork;
    std::thread               mThread;
};
164 
165 /**
166  * The PreparedModelCallback class is used to receive the error status of
167  * preparing a model as well as the prepared model from a task executing
168  * asynchronously with respect to the runtime. If a calling thread calls wait*
169  * or get* on a PreparedModelCallback object and the corresponding asynchronous
170  * task has not finished preparing the model, the calling thread will block
171  * until the asynchronous task has called notify. For more information on the
172  * synchronization behavior, refer to the CallbackBase class.
173  *
174  * This class inherits the basic blocking and signaling calls from
175  * CallbackBase, and implements the HIDL notify call from
176  * IPreparedModelCallback. This callback object is passed as an argument to
177  * IDevice::prepareModel.
178  */
179 class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
180  public:
181     PreparedModelCallback();
182     ~PreparedModelCallback() override;
183 
184     /**
185      * IPreparedModelCallback::notify marks the callback object with the return
186      * status of the asynchronous model preparation along with the prepared
187      * model, and calls CallbackBase::notify, enabling all prior and future
188      * wait* calls on the PreparedModelCallback object to proceed. For more
189      * information on the synchronization behavior, refer to the CallbackBase
190      * class.
191      *
192      * IPreparedModelCallback::notify must be called exactly once on a given
193      * PreparedModelCallback object.
194      *
195      * @param status Error status returned from asynchronously preparing the
196      *               model; will be:
197      *               - NONE if the asynchronous preparation was successful
198      *               - DEVICE_UNAVAILABLE if driver is offline or busy
199      *               - GENERAL_FAILURE if there is an unspecified error
200      *               - INVALID_ARGUMENT if the input model is invalid
201      * @param preparedModel Returned model that has been prepared for execution,
202      *                      nullptr if the model was unable to be prepared.
203      */
204     Return<void> notify(ErrorStatus status, const sp<IPreparedModel>& preparedModel) override;
205 
206     /**
207      * Retrieves the error status returned from the asynchronous task launched
208      * by IDevice::prepareModel. If IDevice::prepareModel has not finished
209      * asynchronously preparing the model, this call will block until the
210      * asynchronous task notifies the object.
211      *
212      * @return status Error status returned from asynchronously preparing the
213      *                model; will be:
214      *                - NONE if the asynchronous preparation was successful
215      *                - DEVICE_UNAVAILABLE if driver is offline or busy
216      *                - GENERAL_FAILURE if there is an unspecified error
217      *                - INVALID_ARGUMENT if the input model is invalid
218      */
219     ErrorStatus getStatus();
220 
221     /**
222      * Retrieves the model that has been prepared for execution from the
223      * asynchronous task launched by IDevice::prepareModel. If
224      * IDevice::prepareModel has not finished asynchronously preparing the
225      * model, this call will block until the asynchronous task notifies the
226      * object.
227      *
228      * @return preparedModel Returned model that has been prepared for
229      *                       execution, nullptr if the model was unable to be
230      *                       prepared.
231      */
232     sp<IPreparedModel> getPreparedModel();
233 
234  private:
235     ErrorStatus        mErrorStatus;
236     sp<IPreparedModel> mPreparedModel;
237 };
238 
/**
 * The ExecutionCallback class is used to receive the error status of the
 * execution from a task executing asynchronously with respect to the runtime.
 * If a calling thread calls wait* or get* on an ExecutionCallback object and
 * the corresponding asynchronous task has not finished the execution, the
 * calling thread will block until the asynchronous task has called notify. For
 * more information on the synchronization behavior, refer to the CallbackBase
 * class.
 *
 * This class inherits the basic blocking and signaling calls from
 * CallbackBase, and implements the HIDL notify call from
 * IExecutionCallback. This callback object is passed as an argument to
 * IPreparedModel::execute.
 */
253 class ExecutionCallback : public CallbackBase,  public IExecutionCallback {
254  public:
255     ExecutionCallback();
256     ~ExecutionCallback() override;
257 
258     /**
259      * IExecutionCallback::notify marks the callback object with the return
260      * status of the asynchronous execution that held this callback and enables
261      * all prior and future wait* calls on the ExecutionCallback object to
262      * proceed. For more information on the synchronization behavior, refer to
263      * the CallbackBase class.
264      *
265      * IExecutionCallback::notify must be called exactly once on a given
266      * ExecutionCallback object.
267      *
268      * @param status Error status returned from asynchronously preparing the
269      *               model; will be:
270      *               - NONE if the asynchronous execution was successful
271      *               - DEVICE_UNAVAILABLE if driver is offline or busy
272      *               - GENERAL_FAILURE if there is an unspecified error
273      *               - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
274      *                 not large enough to store the resultant values
275      *               - INVALID_ARGUMENT if the input request is invalid
276      */
277     Return<void> notify(ErrorStatus status) override;
278 
279     /**
280      * Retrieves the error status returned from the asynchronous task launched
281      * by IPreparedModel::execute. If IPreparedModel::execute has not finished
282      * asynchronously executing, this call will block until the asynchronous task
283      * notifies the object.
284      *
285      * @return status Error status returned from asynchronously preparing the
286      *                model; will be:
287      *                - NONE if the asynchronous execution was successful
288      *                - DEVICE_UNAVAILABLE if driver is offline or busy
289      *                - GENERAL_FAILURE if there is an unspecified error
290      *                - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
291      *                  not large enough to store the resultant values
292      *                - INVALID_ARGUMENT if the input request is invalid
293      */
294     ErrorStatus getStatus();
295 
296  private:
297     ErrorStatus mErrorStatus;
298 };
299 
300 
301 // template function implementation(s) below this point
302 
303 template<class Rep, class Period>
wait_for(const std::chrono::duration<Rep,Period> & timeout_duration)304 std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep,Period>& timeout_duration) {
305     std::unique_lock<std::mutex> lock(mMutex);
306     std::cv_status status = mCondition.wait_for(lock, timeout_duration, [this]{return mNotified;});
307     if (status != std::cv_status::timeout) {
308         join_thread_locked();
309     }
310     return status;
311 }
312 
313 }  // namespace implementation
314 }  // namespace V1_0
315 }  // namespace neuralnetworks
316 }  // namespace hardware
317 }  // namespace android
318 
319 #endif  // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
320