• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
18 #define ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
19 
20 #include <android/hardware/neuralnetworks/1.0/IExecutionCallback.h>
21 #include <android/hardware/neuralnetworks/1.0/IPreparedModelCallback.h>
22 #include <android/hardware/neuralnetworks/1.2/IExecutionCallback.h>
23 #include <android/hardware/neuralnetworks/1.2/IPreparedModelCallback.h>
24 #include <hidl/MQDescriptor.h>
25 #include <hidl/Status.h>
26 #include <chrono>
27 #include <condition_variable>
28 #include <functional>
29 #include <mutex>
30 #include <thread>
31 
32 namespace android {
33 namespace hardware {
34 namespace neuralnetworks {
35 namespace V1_2 {
36 namespace implementation {
37 
38 using V1_0::ErrorStatus;
39 
/**
 * The CallbackBase class is used internally by the NeuralNetworks runtime to
 * synchronize between different threads. An asynchronous task is launched
 * paired with a callback object. When a client thread requires the output being
 * generated by the asynchronous task, the client thread can wait for the result
 * and be blocked until it has completed or a timeout condition has been
 * reached. Any wait* may safely be called concurrently, even on the same
 * callback object. When the asynchronous task has finished its workload, it
 * must immediately call "notify". If the asynchronous task has failed to launch,
 * the function that tried to launch the asynchronous task must immediately call
 * "notify". This "notify" call awakens any client threads waiting on the
 * callback object.
 *
 * The CallbackBase class implements some of the base synchronization common to
 * both PreparedModelCallback and ExecutionCallback. For consistency, any HIDL
 * callback class must inherit from CallbackBase as well as the HIDL callback
 * interface it implements.
 *
 * This class exists to enable synchronization across HIDL. When synchronization
 * is only required in the same process, consider using std::future, std::mutex,
 * std::condition_variable, or std::experimental::latch instead.
 */
class CallbackBase {
 public:
    CallbackBase();
    ~CallbackBase();

    /**
     * CallbackBase::wait blocks until notify has been called on the callback
     * object.
     */
    void wait();

    /**
     * CallbackBase::wait_for blocks until notify has been called on the
     * callback object or the time duration from the time the wait_for function
     * was called has expired, whichever comes first. When it returns
     * std::cv_status::no_timeout, the thread bound with
     * CallbackBase::bind_thread (if any) has also been joined.
     *
     * @param timeout_duration Maximum amount of time to block.
     * @return Status std::cv_status::no_timeout if the callback was notified
     *                before the time duration expired, std::cv_status::timeout
     *                otherwise.
     */
    template <class Rep, class Period>
    std::cv_status wait_for(const std::chrono::duration<Rep, Period>& timeout_duration);

    /**
     * CallbackBase::on_finish binds a function to the callback object. This
     * bound function will be executed when CallbackBase::notify is called,
     * before any calls to wait* return. (Note that CallbackBase::wait_for can
     * return std::cv_status::timeout before CallbackBase::notify is called for
     * the first time, and hence before the bound function is executed.)
     *
     * The bound function must not synchronize with or otherwise access the
     * callback object it is bound to, as this could cause a deadlock.
     *
     * CallbackBase::on_finish can be called at most once on a given callback
     * object, and the call to CallbackBase::on_finish must finish before
     * CallbackBase::notify is called.
     *
     * @param post_work Function to be invoked the first time
     *                  CallbackBase::notify is called. Must have a target --
     *                  i.e., must not compare equal to nullptr. post_work
     *                  returns true if it successfully completes, false if it
     *                  fails.
     * @return bool True if the function was successfully bound, false if
     *              unsuccessful.
     *
     * TODO: Why does the return value of the callback matter?
     */
    bool on_finish(std::function<bool(void)> post_work);

    /**
     * CallbackBase::bind_thread binds a thread to the event for later use by
     * CallbackBase::join_thread.
     *
     * The thread must be passed using std::move.
     *
     * Once a thread is bound with CallbackBase::bind_thread, the client code
     * should ensure that one of the following occurs before the event is
     * destroyed:
     * - CallbackBase::join_thread has been called.
     * - CallbackBase::wait has been called.
     * - CallbackBase::wait_for has been called and returned
     *   std::cv_status::no_timeout (a timed-out wait_for does not join the
     *   thread).
     *
     * The bound thread shall not call any CallbackBase method with the
     * exception of CallbackBase::notify, which it must call when the thread has
     * finished its computation.
     *
     * CallbackBase::bind_thread can be called at most once on a given callback
     * object.
     *
     * @param asyncThread Thread to be bound to the callback object. The thread
     *                    object must represent a thread of execution -- i.e.,
     *                    asyncThread.joinable() must be true.
     * @return bool True if successful, false if thread was not properly bound.
     */
    bool bind_thread(std::thread&& asyncThread);

    /**
     * CallbackBase::join_thread ensures that the thread (if any) bound to this
     * event with CallbackBase::bind_thread has fully finished and cleaned its
     * resources. It is legal to call this function multiple times, concurrently
     * or sequentially.
     */
    void join_thread();

 protected:
    /**
     * CallbackBase::notify enables all prior and future wait* calls on the
     * callback object to proceed. The call to CallbackBase::notify happens
     * before any wait* calls on this callback object return (except in the case
     * of wait_for timing out). The asynchronous call the callback object is
     * paired with must ensure that any update to state that should be visible
     * to the caller of wait* happens before the call to CallbackBase::notify.
     *
     * CallbackBase::notify must be called exactly once on a given callback
     * object.
     */
    void notify();

 private:
    // Same as CallbackBase::join_thread but assumes we already hold a lock on
    // mMutex.
    void join_thread_locked();

    // Whether notify() has been called. wait_for reads it while holding
    // mMutex. The in-class initializer guarantees a defined starting value.
    bool                      mNotified = false;
    std::mutex                mMutex;
    std::condition_variable   mCondition;
    std::function<bool(void)> mPostWork;
    std::thread               mThread;
};
172 
173 /**
174  * The PreparedModelCallback class is used to receive the error status of
175  * preparing a model as well as the prepared model from a task executing
176  * asynchronously with respect to the runtime. If a calling thread calls wait*
177  * or get* on a PreparedModelCallback object and the corresponding asynchronous
178  * task has not finished preparing the model, the calling thread will block
179  * until the asynchronous task has either called notify or notify_1_2. For more
180  * information on the synchronization behavior, refer to the CallbackBase class.
181  *
182  * This class inherits the basic blocking and signaling calls from
183  * CallbackBase, and implements the HIDL notify and notify_1_2 calls from
184  * IPreparedModelCallback. This callback object is passed as an argument to
185  * IDevice::prepareModel.
186  */
187 class PreparedModelCallback : public CallbackBase, public IPreparedModelCallback {
188  public:
189     PreparedModelCallback();
190     ~PreparedModelCallback() override;
191 
192     /**
193      * IPreparedModelCallback::notify and IPreparedModelCallback::notify_1_2
194      * mark the callback object with the return status of the asynchronous
195      * model preparation along with the prepared model, and call
196      * CallbackBase::notify, enabling all prior and future wait* calls on the
197      * PreparedModelCallback object to proceed. For more information on the
198      * synchronization behavior, refer to the CallbackBase class.
199      *
200      * Either IPreparedModelCallback::notify or IPreparedModelCallback::notify_1_2
201      * must be called exactly once on a given PreparedModelCallback object.
202      *
203      * @param status Error status returned from asynchronously preparing the
204      *               model; will be:
205      *               - NONE if the asynchronous preparation was successful
206      *               - DEVICE_UNAVAILABLE if driver is offline or busy
207      *               - GENERAL_FAILURE if there is an unspecified error
208      *               - INVALID_ARGUMENT if the input model is invalid
209      * @param preparedModel Returned model that has been prepared for execution,
210      *                      nullptr if the model was unable to be prepared.
211      */
212     Return<void> notify(ErrorStatus status, const sp<V1_0::IPreparedModel>& preparedModel) override;
213     Return<void> notify_1_2(ErrorStatus status,
214                             const sp<V1_2::IPreparedModel>& preparedModel) override;
215 
216     /**
217      * Retrieves the error status returned from the asynchronous task launched
218      * by IDevice::prepareModel. If IDevice::prepareModel has not finished
219      * asynchronously preparing the model, this call will block until the
220      * asynchronous task notifies the object.
221      *
222      * @return status Error status returned from asynchronously preparing the
223      *                model; will be:
224      *                - NONE if the asynchronous preparation was successful
225      *                - DEVICE_UNAVAILABLE if driver is offline or busy
226      *                - GENERAL_FAILURE if there is an unspecified error
227      *                - INVALID_ARGUMENT if the input model is invalid
228      */
229     ErrorStatus getStatus();
230 
231     /**
232      * Retrieves the model that has been prepared for execution from the
233      * asynchronous task launched by IDevice::prepareModel. If
234      * IDevice::prepareModel has not finished asynchronously preparing the
235      * model, this call will block until the asynchronous task notifies the
236      * object.
237      *
238      * @return preparedModel Returned model that has been prepared for
239      *                       execution, nullptr if the model was unable to be
240      *                       prepared.
241      */
242     sp<V1_0::IPreparedModel> getPreparedModel();
243 
244    private:
245     ErrorStatus        mErrorStatus;
246     sp<V1_0::IPreparedModel> mPreparedModel;
247 };
248 
249 /**
250  * The ExecutionCallback class is used to receive the error status of the
251  * execution from a task executing asynchronously with respect to the runtime.
252  * If a calling thread calls wait* or get* on a PreparedModelCallback object and
253  * the corresponding asynchronous task has not finished the execution, the
254  * calling thread will block until the asynchronous task has either called notify
255  * or notify_1_2. For more information on the synchronization behavior, refer to
256  * the CallbackBase class.
257  *
258  * This class inherits the basic blocking and signaling calls from
259  * CallbackBase, and implements the HIDL notify and notify_1_2 calls from
260  * IExecutionCallback. This callback object is passed as an argument to
261  * IPreparedModel::execute.
262  */
263 class ExecutionCallback : public CallbackBase,  public IExecutionCallback {
264  public:
265     ExecutionCallback();
266     ~ExecutionCallback() override;
267 
268     /**
269      * IExecutionCallback::notify and IExecutionCallback::notify_1_2 mark the
270      * callback object with the return status of the asynchronous execution that
271      * held this callback and enable all prior and future wait* calls on the
272      * ExecutionCallback object to proceed. For more information on the
273      * synchronization behavior, refer to the CallbackBase class.
274      *
275      * Either IExecutionCallback::notify or IExecutionCallback::notify_1_2 must
276      * be called exactly once on a given ExecutionCallback object.
277      *
278      * @param status Error status returned from launching the asynchronous task
279      *               (if the launch fails) or from the asynchronous task itself
280      *               (if the launch succeeds). Must be:
281      *               - NONE if the asynchronous execution was successful
282      *               - DEVICE_UNAVAILABLE if driver is offline or busy
283      *               - GENERAL_FAILURE if there is an unspecified error
284      *               - OUTPUT_INSUFFICIENT_SIZE if provided output buffer is
285      *                 not large enough to store the resultant values
286      *               - INVALID_ARGUMENT if the input request is invalid
287      */
288     Return<void> notify(ErrorStatus status) override;
289 
290     /**
291      * Similar to IExecutionCallback::notify, but for V1_2::IPreparedModel to
292      * also notify output shapes along with error status.
293      *
294      * @param status Error status returned from launching the asynchronous task
295      *               (if the launch fails) or from the asynchronous task itself
296      *               (if the launch succeeds). Must be:
297      *               - NONE if the asynchronous execution was successful
298      *               - DEVICE_UNAVAILABLE if driver is offline or busy
299      *               - GENERAL_FAILURE if the asynchronous task resulted in an
300      *                 unspecified error
301      *               - OUTPUT_INSUFFICIENT_SIZE if at least one output
302      *                 operand buffer is not large enough to store the
303      *                 corresponding output
304      *               - INVALID_ARGUMENT if one of the input arguments to
305      *                 prepareModel is invalid
306      * @param outputShapes A list of shape information of model output operands.
307      *                     The index into "outputShapes" corresponds to the index
308      *                     of the output operand in the Request outputs vector.
309      *                     outputShapes must be empty unless the status is either
310      *                     NONE or OUTPUT_INSUFFICIENT_SIZE.
311      * @return Timing Duration of execution. Unless MeasureTiming::YES was passed when
312      *                launching the execution and status is NONE, all times must
313      *                be reported as UINT64_MAX. A driver may choose to report
314      *                any time as UINT64_MAX, indicating that particular measurement is
315      *                not available.
316      */
317     Return<void> notify_1_2(ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
318                             const Timing& timing) override;
319 
320     // An overload of the latest notify interface to hide the version from ExecutionBuilder.
notify(ErrorStatus status,const hidl_vec<OutputShape> & outputShapes,const Timing & timing)321     Return<void> notify(ErrorStatus status, const hidl_vec<OutputShape>& outputShapes,
322                         const Timing& timing) {
323         return notify_1_2(status, outputShapes, timing);
324     }
325 
326     /**
327      * Retrieves the error status returned from the asynchronous task launched
328      * by either IPreparedModel::execute or IPreparedModel::execute_1_2. If
329      * IPreparedModel::execute or IPreparedModel::execute_1_2 has not finished
330      * asynchronously executing, this call will block until the asynchronous task
331      * notifies the object.
332      *
333      * @return status Error status returned from launching the asynchronous task
334      *                (if the launch fails) or from the asynchronous task itself
335      *                (if the launch succeeds). Must be:
336      *                - NONE if the asynchronous execution was successful
337      *                - DEVICE_UNAVAILABLE if driver is offline or busy
338      *                - GENERAL_FAILURE if the asynchronous task resulted in an
339      *                  unspecified error
340      *                - OUTPUT_INSUFFICIENT_SIZE if at least one output
341      *                  operand buffer is not large enough to store the
342      *                  corresponding output
343      *                - INVALID_ARGUMENT if one of the input arguments to
344      *                  prepareModel is invalid
345      */
346     ErrorStatus getStatus();
347 
348     /**
349      * Retrieves the output shapes returned from the asynchronous task launched
350      * by IPreparedModel::execute_1_2. If IPreparedModel::execute_1_2 has not finished
351      * asynchronously executing, this call will block until the asynchronous task
352      * notifies the object.
353      *
354      * If the asynchronous task was launched by IPreparedModel::execute, an empty vector
355      * will be returned.
356      *
357      * @return outputShapes A list of shape information of model output operands.
358      *                      The index into "outputShapes" corresponds to the index
359      *                      of the output operand in the Request outputs vector.
360      *                      outputShapes must be empty unless the status is either
361      *                      NONE or OUTPUT_INSUFFICIENT_SIZE.
362      */
363     const std::vector<OutputShape>& getOutputShapes();
364 
365     /**
366      * Retrieves the duration of execution ofthe asynchronous task launched
367      * by IPreparedModel::execute_1_2. If IPreparedModel::execute_1_2 has not finished
368      * asynchronously executing, this call will block until the asynchronous task
369      * notifies the object.
370      *
371      * If the asynchronous task was launched by IPreparedModel::execute, every time
372      * must be UINT64_MAX.
373      *
374      * @return timing Duration of the execution. Every time must be UINT64_MAX unless
375      *                the status is NONE.
376      */
377     Timing getTiming();
378 
379    private:
380     ErrorStatus mErrorStatus = ErrorStatus::GENERAL_FAILURE;
381     std::vector<OutputShape> mOutputShapes = {};
382     Timing mTiming = {};
383 };
384 
385 
386 // template function implementation(s) below this point
387 
388 template<class Rep, class Period>
wait_for(const std::chrono::duration<Rep,Period> & timeout_duration)389 std::cv_status CallbackBase::wait_for(const std::chrono::duration<Rep,Period>& timeout_duration) {
390     std::unique_lock<std::mutex> lock(mMutex);
391     std::cv_status status = mCondition.wait_for(lock, timeout_duration, [this]{return mNotified;});
392     if (status != std::cv_status::timeout) {
393         join_thread_locked();
394     }
395     return status;
396 }
397 
398 }  // namespace implementation
399 }  // namespace V1_2
400 }  // namespace neuralnetworks
401 }  // namespace hardware
402 }  // namespace android
403 
404 #endif  // ANDROID_HARDWARE_NEURALNETWORKS_V1_0_CALLBACKS_H
405