• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1  /*
2   * Copyright (C) 2017 The Android Open Source Project
3   *
4   * Licensed under the Apache License, Version 2.0 (the "License");
5   * you may not use this file except in compliance with the License.
6   * You may obtain a copy of the License at
7   *
8   *      http://www.apache.org/licenses/LICENSE-2.0
9   *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  #define LOG_TAG "Callbacks"
18  
19  #include "Callbacks.h"
20  
21  #include <android-base/logging.h>
22  #include <limits>
23  #include <utility>
24  #include <vector>
25  
26  namespace android::nn {
27  
28  using namespace hal;
29  
30  constexpr Timing kNoTiming = {.timeOnDevice = std::numeric_limits<uint64_t>::max(),
31                                .timeInDriver = std::numeric_limits<uint64_t>::max()};
32  
33  // PreparedModelCallback methods begin here
34  
notifyInternal(bool deadObject,ErrorStatus errorStatus,const sp<V1_0::IPreparedModel> & preparedModel)35  Return<void> PreparedModelCallback::notifyInternal(bool deadObject, ErrorStatus errorStatus,
36                                                     const sp<V1_0::IPreparedModel>& preparedModel) {
37      {
38          std::lock_guard<std::mutex> hold(mMutex);
39  
40          // quick-return if object has already been notified
41          if (mNotified) {
42              return Void();
43          }
44  
45          // store results and mark as notified
46          mDeadObject = deadObject;
47          mErrorStatus = errorStatus;
48          mPreparedModel = preparedModel;
49          mNotified = true;
50      }
51  
52      mCondition.notify_all();
53      return Void();
54  }
55  
notify(V1_0::ErrorStatus errorStatus,const sp<V1_0::IPreparedModel> & preparedModel)56  Return<void> PreparedModelCallback::notify(V1_0::ErrorStatus errorStatus,
57                                             const sp<V1_0::IPreparedModel>& preparedModel) {
58      return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), preparedModel);
59  }
60  
notify_1_2(V1_0::ErrorStatus errorStatus,const sp<V1_2::IPreparedModel> & preparedModel)61  Return<void> PreparedModelCallback::notify_1_2(V1_0::ErrorStatus errorStatus,
62                                                 const sp<V1_2::IPreparedModel>& preparedModel) {
63      return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), preparedModel);
64  }
65  
notify_1_3(ErrorStatus errorStatus,const sp<V1_3::IPreparedModel> & preparedModel)66  Return<void> PreparedModelCallback::notify_1_3(ErrorStatus errorStatus,
67                                                 const sp<V1_3::IPreparedModel>& preparedModel) {
68      return notifyInternal(false, errorStatus, preparedModel);
69  }
70  
notifyAsDeadObject()71  void PreparedModelCallback::notifyAsDeadObject() {
72      notifyInternal(true, ErrorStatus::GENERAL_FAILURE, nullptr);
73  }
74  
wait() const75  void PreparedModelCallback::wait() const {
76      std::unique_lock<std::mutex> lock(mMutex);
77      mCondition.wait(lock, [this] { return mNotified; });
78  }
79  
getStatus() const80  ErrorStatus PreparedModelCallback::getStatus() const {
81      wait();
82      return mErrorStatus;
83  }
84  
getPreparedModel() const85  sp<V1_0::IPreparedModel> PreparedModelCallback::getPreparedModel() const {
86      wait();
87      return mPreparedModel;
88  }
89  
isDeadObject() const90  bool PreparedModelCallback::isDeadObject() const {
91      wait();
92      return mDeadObject;
93  }
94  
95  // ExecutionCallback methods begin here
96  
notify(V1_0::ErrorStatus errorStatus)97  Return<void> ExecutionCallback::notify(V1_0::ErrorStatus errorStatus) {
98      return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), {}, kNoTiming);
99  }
100  
notify_1_2(V1_0::ErrorStatus errorStatus,const hidl_vec<OutputShape> & outputShapes,const Timing & timing)101  Return<void> ExecutionCallback::notify_1_2(V1_0::ErrorStatus errorStatus,
102                                             const hidl_vec<OutputShape>& outputShapes,
103                                             const Timing& timing) {
104      return notifyInternal(false, static_cast<ErrorStatus>(errorStatus), outputShapes, timing);
105  }
106  
notify_1_3(V1_3::ErrorStatus errorStatus,const hidl_vec<OutputShape> & outputShapes,const Timing & timing)107  Return<void> ExecutionCallback::notify_1_3(V1_3::ErrorStatus errorStatus,
108                                             const hidl_vec<OutputShape>& outputShapes,
109                                             const Timing& timing) {
110      return notifyInternal(false, errorStatus, outputShapes, timing);
111  }
112  
notifyAsDeadObject()113  void ExecutionCallback::notifyAsDeadObject() {
114      notifyInternal(true, ErrorStatus::GENERAL_FAILURE, {}, kNoTiming);
115  }
116  
wait() const117  void ExecutionCallback::wait() const {
118      std::unique_lock<std::mutex> lock(mMutex);
119      mCondition.wait(lock, [this] { return mNotified; });
120  
121      /*
122       * Note that we cannot call std::thread::join from ExecutionCallback's
123       * destructor: ExecutionCallback is intended to be reference counted, and it
124       * is possible that the reference count drops to zero in the bound thread,
125       * causing the bound thread to call this destructor. If a thread tries to
126       * join itself, it throws an exception, producing a message like the
127       * following:
128       *
129       *     terminating with uncaught exception of type std::__1::system_error:
130       *     thread::join failed: Resource deadlock would occur
131       */
132      if (mThread.joinable()) {
133          mThread.join();
134      }
135  }
136  
getStatus() const137  ErrorStatus ExecutionCallback::getStatus() const {
138      wait();
139      return mErrorStatus;
140  }
141  
getOutputShapes() const142  const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {
143      wait();
144      return mOutputShapes;
145  }
146  
getTiming() const147  Timing ExecutionCallback::getTiming() const {
148      wait();
149      return mTiming;
150  }
151  
isDeadObject() const152  bool ExecutionCallback::isDeadObject() const {
153      wait();
154      return mDeadObject;
155  }
156  
bindThread(std::thread asyncThread)157  bool ExecutionCallback::bindThread(std::thread asyncThread) {
158      std::lock_guard<std::mutex> lock(mMutex);
159  
160      // Ensure ExecutionCallback object does not already have a thread bound
161      if (mThread.joinable()) {
162          LOG(ERROR) << "ExecutionCallback::bindThread -- a thread has already been bound to this "
163                        "callback object";
164          return false;
165      }
166  
167      // Ensure the new thread is valid
168      if (!asyncThread.joinable()) {
169          LOG(ERROR) << "ExecutionCallback::bindThread -- the new thread is not joinable";
170          return false;
171      }
172  
173      mThread = std::move(asyncThread);
174      return true;
175  }
176  
setOnFinish(const ExecutionFinish & finish)177  void ExecutionCallback::setOnFinish(const ExecutionFinish& finish) {
178      std::lock_guard<std::mutex> hold(mMutex);
179  
180      // Ensure ExecutionCallback object does not already have a "finish" callback
181      if (mOnFinish != nullptr) {
182          LOG(ERROR) << "ExecutionCallback::setOnFinish -- object already has a \"finish\" callback";
183          return;
184      }
185  
186      // Ensure new "finish" callback is valid
187      if (finish == nullptr) {
188          LOG(ERROR) << "ExecutionCallback::setOnFinish -- \"finish\" callback is invalid";
189          return;
190      }
191  
192      // Essure ExecutionCallback object has not already been notified
193      if (mNotified) {
194          LOG(ERROR) << "ExecutionCallback::setOnFinish -- ExecutionCallback has already been "
195                        "notified with results";
196          return;
197      }
198  
199      mOnFinish = finish;
200  }
201  
notifyInternal(bool deadObject,ErrorStatus errorStatus,std::vector<OutputShape> outputShapes,Timing timing)202  Return<void> ExecutionCallback::notifyInternal(bool deadObject, ErrorStatus errorStatus,
203                                                 std::vector<OutputShape> outputShapes,
204                                                 Timing timing) {
205      // check results
206      if (!deadObject) {
207          if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
208              // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
209              if (outputShapes.size() == 0) {
210                  LOG(ERROR)
211                          << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
212                  errorStatus = ErrorStatus::GENERAL_FAILURE;
213                  outputShapes = {};
214                  timing = kNoTiming;
215              }
216          } else if (errorStatus != ErrorStatus::NONE) {
217              // outputShapes must be empty if errorStatus is neither NONE nor
218              // OUTPUT_INSUFFICIENT_SIZE.
219              if (outputShapes.size() != 0) {
220                  LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
221                                "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
222                  errorStatus = ErrorStatus::GENERAL_FAILURE;
223                  outputShapes = {};
224                  timing = kNoTiming;
225              }
226          }
227      }
228  
229      // store results
230      {
231          std::lock_guard<std::mutex> hold(mMutex);
232  
233          // quick-return if object has already been notified
234          if (mNotified) {
235              return Void();
236          }
237  
238          mDeadObject = deadObject;
239          mErrorStatus = errorStatus;
240          mOutputShapes = std::move(outputShapes);
241          mTiming = timing;
242          mNotified = true;
243  
244          if (mOnFinish != nullptr) {
245              ErrorStatus status = mOnFinish(mErrorStatus, mOutputShapes);
246              mOnFinish = nullptr;
247              if (status != ErrorStatus::NONE) {
248                  mErrorStatus = status;
249              }
250          }
251      }
252      mCondition.notify_all();
253      return Void();
254  }
255  
256  }  // namespace android::nn
257