1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H 18 #define ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H 19 20 #include "HalInterfaces.h" 21 22 #include <android-base/macros.h> 23 #include <fmq/MessageQueue.h> 24 #include <hidl/MQDescriptor.h> 25 26 #include <atomic> 27 #include <memory> 28 #include <optional> 29 #include <thread> 30 #include <vector> 31 32 namespace android::nn { 33 34 using ::android::hardware::MQDescriptorSync; 35 using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>; 36 using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>; 37 38 /** 39 * Function to serialize results. 40 * 41 * Prefer calling ResultChannelSender::send. 42 * 43 * @param errorStatus Status of the execution. 44 * @param outputShapes Dynamic shapes of the output tensors. 45 * @param timing Timing information of the execution. 46 * @return Serialized FMQ result data. 47 */ 48 std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus, 49 const std::vector<OutputShape>& outputShapes, Timing timing); 50 51 /** 52 * Deserialize the FMQ request data. 53 * 54 * The three resulting fields are the Request object (where Request::pools is 55 * empty), slot identifiers (which are stand-ins for Request::pools), and 56 * whether timing information must be collected for the run. 57 * 58 * @param data Serialized FMQ request data. 59 * @return Request object if successfully deserialized, std::nullopt otherwise. 60 */ 61 std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize( 62 const std::vector<FmqRequestDatum>& data); 63 64 /** 65 * RequestChannelReceiver is responsible for waiting on the channel until the 66 * packet is available, extracting the packet from the channel, and 67 * deserializing the packet. 68 * 69 * Because the receiver can wait on a packet that may never come (e.g., because 70 * the sending side of the packet has been closed), this object can be 71 * invalidating, unblocking the receiver. 72 */ 73 class RequestChannelReceiver { 74 using FmqRequestChannel = 75 hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>; 76 77 public: 78 /** 79 * Create the receiving end of a request channel. 80 * 81 * Prefer this call over the constructor. 82 * 83 * @param requestChannel Descriptor for the request channel. 84 * @return RequestChannelReceiver on successful creation, nullptr otherwise. 85 */ 86 static std::unique_ptr<RequestChannelReceiver> create( 87 const FmqRequestDescriptor& requestChannel); 88 89 /** 90 * Get the request from the channel. 91 * 92 * This method will block until either: 93 * 1) The packet has been retrieved, or 94 * 2) The receiver has been invalidated 95 * 96 * @return Request object if successfully received, std::nullopt if error or 97 * if the receiver object was invalidated. 98 */ 99 std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> getBlocking(); 100 101 /** 102 * Method to mark the channel as invalid, unblocking any current or future 103 * calls to RequestChannelReceiver::getBlocking. 104 */ 105 void invalidate(); 106 107 RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking); 108 109 private: 110 std::optional<std::vector<FmqRequestDatum>> getPacketBlocking(); 111 112 const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel; 113 std::atomic<bool> mTeardown{false}; 114 const bool mBlocking; 115 }; 116 117 /** 118 * ResultChannelSender is responsible for serializing the result packet of 119 * information, sending it on the result channel, and signaling that the data is 120 * available. 121 */ 122 class ResultChannelSender { 123 using FmqResultChannel = 124 hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>; 125 126 public: 127 /** 128 * Create the sending end of a result channel. 129 * 130 * Prefer this call over the constructor. 131 * 132 * @param resultChannel Descriptor for the result channel. 133 * @return ResultChannelSender on successful creation, nullptr otherwise. 134 */ 135 static std::unique_ptr<ResultChannelSender> create(const FmqResultDescriptor& resultChannel); 136 137 /** 138 * Send the result to the channel. 139 * 140 * @param errorStatus Status of the execution. 141 * @param outputShapes Dynamic shapes of the output tensors. 142 * @param timing Timing information of the execution. 143 * @return 'true' on successful send, 'false' otherwise. 144 */ 145 bool send(ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing); 146 147 // prefer calling ResultChannelSender::send 148 bool sendPacket(const std::vector<FmqResultDatum>& packet); 149 150 ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking); 151 152 private: 153 const std::unique_ptr<FmqResultChannel> mFmqResultChannel; 154 const bool mBlocking; 155 }; 156 157 /** 158 * The ExecutionBurstServer class is responsible for waiting for and 159 * deserializing a request object from a FMQ, performing the inference, and 160 * serializing the result back across another FMQ. 161 */ 162 class ExecutionBurstServer : public IBurstContext { 163 DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer); 164 165 public: 166 /** 167 * IBurstExecutorWithCache is a callback object passed to 168 * ExecutionBurstServer's factory function that is used to perform an 169 * execution. Because some memory resources are needed across multiple 170 * executions, this object also contains a local cache that can directly be 171 * used in the execution. 172 * 173 * ExecutionBurstServer will never access its IBurstExecutorWithCache object 174 * with concurrent calls. 175 */ 176 class IBurstExecutorWithCache { 177 DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache); 178 179 public: 180 IBurstExecutorWithCache() = default; 181 virtual ~IBurstExecutorWithCache() = default; 182 183 /** 184 * Checks if a cache entry specified by a slot is present in the cache. 185 * 186 * @param slot Identifier of the cache entry. 187 * @return 'true' if the cache entry is present in the cache, 'false' 188 * otherwise. 189 */ 190 virtual bool isCacheEntryPresent(int32_t slot) const = 0; 191 192 /** 193 * Adds an entry specified by a slot to the cache. 194 * 195 * The caller of this function must ensure that the cache entry that is 196 * being added is not already present in the cache. This can be checked 197 * via isCacheEntryPresent. 198 * 199 * @param memory Memory resource to be cached. 200 * @param slot Slot identifier corresponding to the memory resource. 201 */ 202 virtual void addCacheEntry(const hidl_memory& memory, int32_t slot) = 0; 203 204 /** 205 * Removes an entry specified by a slot from the cache. 206 * 207 * If the cache entry corresponding to the slot number does not exist, 208 * the call does nothing. 209 * 210 * @param slot Slot identifier corresponding to the memory resource. 211 */ 212 virtual void removeCacheEntry(int32_t slot) = 0; 213 214 /** 215 * Perform an execution. 216 * 217 * @param request Request object with inputs and outputs specified. 218 * Request::pools is empty, and DataLocation::poolIndex instead 219 * refers to the 'slots' argument as if it were Request::pools. 220 * @param slots Slots corresponding to the cached memory entries to be 221 * used. 222 * @param measure Whether timing information is requested for the 223 * execution. 224 * @return Result of the execution, including the status of the 225 * execution, dynamic output shapes, and any timing information. 226 */ 227 virtual std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute( 228 const Request& request, const std::vector<int32_t>& slots, 229 MeasureTiming measure) = 0; 230 }; 231 232 /** 233 * Create automated context to manage FMQ-based executions. 234 * 235 * This function is intended to be used by a service to automatically: 236 * 1) Receive data from a provided FMQ 237 * 2) Execute a model with the given information 238 * 3) Send the result to the created FMQ 239 * 240 * @param callback Callback used to retrieve memories corresponding to 241 * unrecognized slots. 242 * @param requestChannel Input FMQ channel through which the client passes the 243 * request to the service. 244 * @param resultChannel Output FMQ channel from which the client can retrieve 245 * the result of the execution. 246 * @param executorWithCache Object which maintains a local cache of the 247 * memory pools and executes using the cached memory pools. 248 * @result IBurstContext Handle to the burst context. 249 */ 250 static sp<ExecutionBurstServer> create( 251 const sp<IBurstCallback>& callback, const FmqRequestDescriptor& requestChannel, 252 const FmqResultDescriptor& resultChannel, 253 std::shared_ptr<IBurstExecutorWithCache> executorWithCache); 254 255 /** 256 * Create automated context to manage FMQ-based executions. 257 * 258 * This function is intended to be used by a service to automatically: 259 * 1) Receive data from a provided FMQ 260 * 2) Execute a model with the given information 261 * 3) Send the result to the created FMQ 262 * 263 * @param callback Callback used to retrieve memories corresponding to 264 * unrecognized slots. 265 * @param requestChannel Input FMQ channel through which the client passes the 266 * request to the service. 267 * @param resultChannel Output FMQ channel from which the client can retrieve 268 * the result of the execution. 269 * @param preparedModel PreparedModel that the burst object was created from. 270 * IPreparedModel::executeSynchronously will be used to perform the 271 * execution. 272 * @result IBurstContext Handle to the burst context. 273 */ 274 static sp<ExecutionBurstServer> create(const sp<IBurstCallback>& callback, 275 const FmqRequestDescriptor& requestChannel, 276 const FmqResultDescriptor& resultChannel, 277 IPreparedModel* preparedModel); 278 279 ExecutionBurstServer(const sp<IBurstCallback>& callback, 280 std::unique_ptr<RequestChannelReceiver> requestChannel, 281 std::unique_ptr<ResultChannelSender> resultChannel, 282 std::shared_ptr<IBurstExecutorWithCache> cachedExecutor); 283 ~ExecutionBurstServer(); 284 285 // Used by the NN runtime to preemptively remove any stored memory. 286 Return<void> freeMemory(int32_t slot) override; 287 288 private: 289 // Ensures all cache entries contained in mExecutorWithCache are present in 290 // the cache. If they are not present, they are retrieved (via 291 // IBurstCallback::getMemories) and added to mExecutorWithCache. 292 // 293 // This method is locked via mMutex when it is called. 294 void ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots); 295 296 // Work loop that will continue processing execution requests until the 297 // ExecutionBurstServer object is freed. 298 void task(); 299 300 std::thread mWorker; 301 std::mutex mMutex; 302 std::atomic<bool> mTeardown{false}; 303 const sp<IBurstCallback> mCallback; 304 const std::unique_ptr<RequestChannelReceiver> mRequestChannelReceiver; 305 const std::unique_ptr<ResultChannelSender> mResultChannelSender; 306 const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache; 307 }; 308 309 } // namespace android::nn 310 311 #endif // ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H 312