1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ExecutionCallback"
18
19 #include "ExecutionCallback.h"
20
21 #include <android-base/logging.h>
22
23 #include <limits>
24 #include <utility>
25 #include <vector>
26
27 namespace android::nn {
28
notify(ErrorStatus status,const std::vector<OutputShape> & outputShapes,const Timing & timing)29 void ExecutionCallback::notify(ErrorStatus status, const std::vector<OutputShape>& outputShapes,
30 const Timing& timing) {
31 notifyInternal(status, outputShapes, timing);
32 }
33
wait() const34 void ExecutionCallback::wait() const {
35 std::unique_lock<std::mutex> lock(mMutex);
36 mCondition.wait(lock, [this] { return mNotified; });
37
38 /*
39 * Note that we cannot call std::thread::join from ExecutionCallback's
40 * destructor: ExecutionCallback is intended to be reference counted, and it
41 * is possible that the reference count drops to zero in the bound thread,
42 * causing the bound thread to call this destructor. If a thread tries to
43 * join itself, it throws an exception, producing a message like the
44 * following:
45 *
46 * terminating with uncaught exception of type std::__1::system_error:
47 * thread::join failed: Resource deadlock would occur
48 */
49 if (mThread.joinable()) {
50 mThread.join();
51 }
52 }
53
getStatus() const54 ErrorStatus ExecutionCallback::getStatus() const {
55 wait();
56 return mErrorStatus;
57 }
58
getOutputShapes() const59 const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {
60 wait();
61 return mOutputShapes;
62 }
63
getTiming() const64 Timing ExecutionCallback::getTiming() const {
65 wait();
66 return mTiming;
67 }
68
bindThread(std::thread asyncThread)69 bool ExecutionCallback::bindThread(std::thread asyncThread) {
70 std::lock_guard<std::mutex> lock(mMutex);
71
72 // Ensure ExecutionCallback object does not already have a thread bound
73 if (mThread.joinable()) {
74 LOG(ERROR) << "ExecutionCallback::bindThread -- a thread has already been bound to this "
75 "callback object";
76 return false;
77 }
78
79 // Ensure the new thread is valid
80 if (!asyncThread.joinable()) {
81 LOG(ERROR) << "ExecutionCallback::bindThread -- the new thread is not joinable";
82 return false;
83 }
84
85 mThread = std::move(asyncThread);
86 return true;
87 }
88
setOnFinish(const ExecutionFinish & finish)89 void ExecutionCallback::setOnFinish(const ExecutionFinish& finish) {
90 std::lock_guard<std::mutex> hold(mMutex);
91
92 // Ensure ExecutionCallback object does not already have a "finish" callback
93 if (mOnFinish != nullptr) {
94 LOG(ERROR) << "ExecutionCallback::setOnFinish -- object already has a \"finish\" callback";
95 return;
96 }
97
98 // Ensure new "finish" callback is valid
99 if (finish == nullptr) {
100 LOG(ERROR) << "ExecutionCallback::setOnFinish -- \"finish\" callback is invalid";
101 return;
102 }
103
104 // Essure ExecutionCallback object has not already been notified
105 if (mNotified) {
106 LOG(ERROR) << "ExecutionCallback::setOnFinish -- ExecutionCallback has already been "
107 "notified with results";
108 return;
109 }
110
111 mOnFinish = finish;
112 }
113
notifyInternal(ErrorStatus errorStatus,std::vector<OutputShape> outputShapes,Timing timing)114 void ExecutionCallback::notifyInternal(ErrorStatus errorStatus,
115 std::vector<OutputShape> outputShapes, Timing timing) {
116 // check results
117 {
118 if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
119 // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
120 if (outputShapes.size() == 0) {
121 LOG(ERROR)
122 << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
123 errorStatus = ErrorStatus::GENERAL_FAILURE;
124 outputShapes = {};
125 timing = {};
126 }
127 } else if (errorStatus != ErrorStatus::NONE) {
128 // outputShapes must be empty if errorStatus is neither NONE nor
129 // OUTPUT_INSUFFICIENT_SIZE.
130 if (outputShapes.size() != 0) {
131 LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
132 "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
133 errorStatus = ErrorStatus::GENERAL_FAILURE;
134 outputShapes = {};
135 timing = {};
136 }
137 }
138 }
139
140 // store results
141 {
142 std::lock_guard<std::mutex> hold(mMutex);
143
144 // quick-return if object has already been notified
145 if (mNotified) {
146 return;
147 }
148
149 mErrorStatus = errorStatus;
150 mOutputShapes = std::move(outputShapes);
151 mTiming = timing;
152 mNotified = true;
153
154 if (mOnFinish != nullptr) {
155 ErrorStatus status = mOnFinish(mErrorStatus, mOutputShapes);
156 mOnFinish = nullptr;
157 if (status != ErrorStatus::NONE) {
158 mErrorStatus = status;
159 }
160 }
161 }
162 mCondition.notify_all();
163 }
164
165 } // namespace android::nn
166