1 #include "Callbacks.h" 2 #include <android-base/logging.h> 3 4 namespace android { 5 namespace hardware { 6 namespace neuralnetworks { 7 namespace V1_0 { 8 namespace implementation { 9 CallbackBase()10CallbackBase::CallbackBase() : mNotified(false) {} 11 ~CallbackBase()12CallbackBase::~CallbackBase() { 13 // Note that we cannot call CallbackBase::join_thread from here: 14 // CallbackBase is intended to be reference counted, and it is possible that 15 // the reference count drops to zero in the bound thread, causing the 16 // bound thread to call this destructor. If a thread tries to join 17 // itself, it throws an exception, producing a message like the 18 // following: 19 // 20 // terminating with uncaught exception of type std::__1::system_error: 21 // thread::join failed: Resource deadlock would occur 22 } 23 wait()24void CallbackBase::wait() { 25 std::unique_lock<std::mutex> lock(mMutex); 26 mCondition.wait(lock, [this]{return mNotified;}); 27 join_thread_locked(); 28 } 29 on_finish(std::function<bool (void)> post_work)30bool CallbackBase::on_finish(std::function<bool(void)> post_work) { 31 std::lock_guard<std::mutex> lock(mMutex); 32 if (mPostWork != nullptr) { 33 LOG(ERROR) << "CallbackBase::on_finish -- a post-work function has already been bound to " 34 "this callback object"; 35 return false; 36 } 37 if (post_work == nullptr) { 38 LOG(ERROR) << "CallbackBase::on_finish -- the new post-work function is invalid"; 39 return false; 40 } 41 mPostWork = std::move(post_work); 42 return true; 43 } 44 bind_thread(std::thread && asyncThread)45bool CallbackBase::bind_thread(std::thread&& asyncThread) { 46 std::lock_guard<std::mutex> lock(mMutex); 47 if (mThread.joinable()) { 48 LOG(ERROR) << "CallbackBase::bind_thread -- a thread has already been bound to this " 49 "callback object"; 50 return false; 51 } 52 if (!asyncThread.joinable()) { 53 LOG(ERROR) << "CallbackBase::bind_thread -- the new thread is not joinable"; 54 return false; 55 } 56 mThread = std::move(asyncThread); 57 return true; 58 } 59 join_thread()60void CallbackBase::join_thread() { 61 std::lock_guard<std::mutex> lock(mMutex); 62 join_thread_locked(); 63 } 64 notify()65void CallbackBase::notify() { 66 { 67 std::lock_guard<std::mutex> lock(mMutex); 68 mNotified = true; 69 if (mPostWork != nullptr) { 70 bool success = mPostWork(); 71 if (!success) { 72 LOG(ERROR) << "CallbackBase::notify -- post work failed"; 73 } 74 } 75 } 76 mCondition.notify_all(); 77 } 78 join_thread_locked()79void CallbackBase::join_thread_locked() { 80 if (mThread.joinable()) { 81 mThread.join(); 82 } 83 } 84 PreparedModelCallback()85PreparedModelCallback::PreparedModelCallback() : 86 mErrorStatus(ErrorStatus::GENERAL_FAILURE), mPreparedModel(nullptr) {} 87 ~PreparedModelCallback()88PreparedModelCallback::~PreparedModelCallback() {} 89 notify(ErrorStatus errorStatus,const sp<IPreparedModel> & preparedModel)90Return<void> PreparedModelCallback::notify(ErrorStatus errorStatus, 91 const sp<IPreparedModel>& preparedModel) { 92 mErrorStatus = errorStatus; 93 mPreparedModel = preparedModel; 94 CallbackBase::notify(); 95 return Void(); 96 } 97 getStatus()98ErrorStatus PreparedModelCallback::getStatus() { 99 wait(); 100 return mErrorStatus; 101 } 102 getPreparedModel()103sp<IPreparedModel> PreparedModelCallback::getPreparedModel() { 104 wait(); 105 return mPreparedModel; 106 } 107 ExecutionCallback()108ExecutionCallback::ExecutionCallback() : mErrorStatus(ErrorStatus::GENERAL_FAILURE) {} 109 ~ExecutionCallback()110ExecutionCallback::~ExecutionCallback() {} 111 notify(ErrorStatus errorStatus)112Return<void> ExecutionCallback::notify(ErrorStatus errorStatus) { 113 mErrorStatus = errorStatus; 114 CallbackBase::notify(); 115 return Void(); 116 } 117 getStatus()118ErrorStatus ExecutionCallback::getStatus() { 119 wait(); 120 return mErrorStatus; 121 } 122 123 } // namespace implementation 124 } // namespace V1_0 125 } // namespace neuralnetworks 126 } // namespace hardware 127 } // namespace android 128