• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
18 #define ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
19 
#include "HalInterfaces.h"

#include <android-base/macros.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>

#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <tuple>
#include <vector>
31 
32 namespace android::nn {
33 
34 using ::android::hardware::MQDescriptorSync;
35 using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
36 using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
37 
38 /**
39  * Function to serialize results.
40  *
41  * Prefer calling ResultChannelSender::send.
42  *
43  * @param errorStatus Status of the execution.
44  * @param outputShapes Dynamic shapes of the output tensors.
45  * @param timing Timing information of the execution.
46  * @return Serialized FMQ result data.
47  */
48 std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
49                                       const std::vector<OutputShape>& outputShapes, Timing timing);
50 
51 /**
52  * Deserialize the FMQ request data.
53  *
54  * The three resulting fields are the Request object (where Request::pools is
55  * empty), slot identifiers (which are stand-ins for Request::pools), and
56  * whether timing information must be collected for the run.
57  *
58  * @param data Serialized FMQ request data.
59  * @return Request object if successfully deserialized, std::nullopt otherwise.
60  */
61 std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
62         const std::vector<FmqRequestDatum>& data);
63 
64 /**
65  * RequestChannelReceiver is responsible for waiting on the channel until the
66  * packet is available, extracting the packet from the channel, and
67  * deserializing the packet.
68  *
69  * Because the receiver can wait on a packet that may never come (e.g., because
70  * the sending side of the packet has been closed), this object can be
71  * invalidating, unblocking the receiver.
72  */
73 class RequestChannelReceiver {
74     using FmqRequestChannel =
75             hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>;
76 
77    public:
78     /**
79      * Create the receiving end of a request channel.
80      *
81      * Prefer this call over the constructor.
82      *
83      * @param requestChannel Descriptor for the request channel.
84      * @return RequestChannelReceiver on successful creation, nullptr otherwise.
85      */
86     static std::unique_ptr<RequestChannelReceiver> create(
87             const FmqRequestDescriptor& requestChannel);
88 
89     /**
90      * Get the request from the channel.
91      *
92      * This method will block until either:
93      * 1) The packet has been retrieved, or
94      * 2) The receiver has been invalidated
95      *
96      * @return Request object if successfully received, std::nullopt if error or
97      *     if the receiver object was invalidated.
98      */
99     std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> getBlocking();
100 
101     /**
102      * Method to mark the channel as invalid, unblocking any current or future
103      * calls to RequestChannelReceiver::getBlocking.
104      */
105     void invalidate();
106 
107     RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking);
108 
109    private:
110     std::optional<std::vector<FmqRequestDatum>> getPacketBlocking();
111 
112     const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
113     std::atomic<bool> mTeardown{false};
114     const bool mBlocking;
115 };
116 
117 /**
118  * ResultChannelSender is responsible for serializing the result packet of
119  * information, sending it on the result channel, and signaling that the data is
120  * available.
121  */
122 class ResultChannelSender {
123     using FmqResultChannel =
124             hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>;
125 
126    public:
127     /**
128      * Create the sending end of a result channel.
129      *
130      * Prefer this call over the constructor.
131      *
132      * @param resultChannel Descriptor for the result channel.
133      * @return ResultChannelSender on successful creation, nullptr otherwise.
134      */
135     static std::unique_ptr<ResultChannelSender> create(const FmqResultDescriptor& resultChannel);
136 
137     /**
138      * Send the result to the channel.
139      *
140      * @param errorStatus Status of the execution.
141      * @param outputShapes Dynamic shapes of the output tensors.
142      * @param timing Timing information of the execution.
143      * @return 'true' on successful send, 'false' otherwise.
144      */
145     bool send(ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing);
146 
147     // prefer calling ResultChannelSender::send
148     bool sendPacket(const std::vector<FmqResultDatum>& packet);
149 
150     ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking);
151 
152    private:
153     const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
154     const bool mBlocking;
155 };
156 
157 /**
158  * The ExecutionBurstServer class is responsible for waiting for and
159  * deserializing a request object from a FMQ, performing the inference, and
160  * serializing the result back across another FMQ.
161  */
162 class ExecutionBurstServer : public IBurstContext {
163     DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);
164 
165    public:
166     /**
167      * IBurstExecutorWithCache is a callback object passed to
168      * ExecutionBurstServer's factory function that is used to perform an
169      * execution. Because some memory resources are needed across multiple
170      * executions, this object also contains a local cache that can directly be
171      * used in the execution.
172      *
173      * ExecutionBurstServer will never access its IBurstExecutorWithCache object
174      * with concurrent calls.
175      */
176     class IBurstExecutorWithCache {
177         DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache);
178 
179        public:
180         IBurstExecutorWithCache() = default;
181         virtual ~IBurstExecutorWithCache() = default;
182 
183         /**
184          * Checks if a cache entry specified by a slot is present in the cache.
185          *
186          * @param slot Identifier of the cache entry.
187          * @return 'true' if the cache entry is present in the cache, 'false'
188          *     otherwise.
189          */
190         virtual bool isCacheEntryPresent(int32_t slot) const = 0;
191 
192         /**
193          * Adds an entry specified by a slot to the cache.
194          *
195          * The caller of this function must ensure that the cache entry that is
196          * being added is not already present in the cache. This can be checked
197          * via isCacheEntryPresent.
198          *
199          * @param memory Memory resource to be cached.
200          * @param slot Slot identifier corresponding to the memory resource.
201          */
202         virtual void addCacheEntry(const hidl_memory& memory, int32_t slot) = 0;
203 
204         /**
205          * Removes an entry specified by a slot from the cache.
206          *
207          * If the cache entry corresponding to the slot number does not exist,
208          * the call does nothing.
209          *
210          * @param slot Slot identifier corresponding to the memory resource.
211          */
212         virtual void removeCacheEntry(int32_t slot) = 0;
213 
214         /**
215          * Perform an execution.
216          *
217          * @param request Request object with inputs and outputs specified.
218          *     Request::pools is empty, and DataLocation::poolIndex instead
219          *     refers to the 'slots' argument as if it were Request::pools.
220          * @param slots Slots corresponding to the cached memory entries to be
221          *     used.
222          * @param measure Whether timing information is requested for the
223          *     execution.
224          * @return Result of the execution, including the status of the
225          *     execution, dynamic output shapes, and any timing information.
226          */
227         virtual std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
228                 const Request& request, const std::vector<int32_t>& slots,
229                 MeasureTiming measure) = 0;
230     };
231 
232     /**
233      * Create automated context to manage FMQ-based executions.
234      *
235      * This function is intended to be used by a service to automatically:
236      * 1) Receive data from a provided FMQ
237      * 2) Execute a model with the given information
238      * 3) Send the result to the created FMQ
239      *
240      * @param callback Callback used to retrieve memories corresponding to
241      *     unrecognized slots.
242      * @param requestChannel Input FMQ channel through which the client passes the
243      *     request to the service.
244      * @param resultChannel Output FMQ channel from which the client can retrieve
245      *     the result of the execution.
246      * @param executorWithCache Object which maintains a local cache of the
247      *     memory pools and executes using the cached memory pools.
248      * @result IBurstContext Handle to the burst context.
249      */
250     static sp<ExecutionBurstServer> create(
251             const sp<IBurstCallback>& callback, const FmqRequestDescriptor& requestChannel,
252             const FmqResultDescriptor& resultChannel,
253             std::shared_ptr<IBurstExecutorWithCache> executorWithCache);
254 
255     /**
256      * Create automated context to manage FMQ-based executions.
257      *
258      * This function is intended to be used by a service to automatically:
259      * 1) Receive data from a provided FMQ
260      * 2) Execute a model with the given information
261      * 3) Send the result to the created FMQ
262      *
263      * @param callback Callback used to retrieve memories corresponding to
264      *     unrecognized slots.
265      * @param requestChannel Input FMQ channel through which the client passes the
266      *     request to the service.
267      * @param resultChannel Output FMQ channel from which the client can retrieve
268      *     the result of the execution.
269      * @param preparedModel PreparedModel that the burst object was created from.
270      *     IPreparedModel::executeSynchronously will be used to perform the
271      *     execution.
272      * @result IBurstContext Handle to the burst context.
273      */
274     static sp<ExecutionBurstServer> create(const sp<IBurstCallback>& callback,
275                                            const FmqRequestDescriptor& requestChannel,
276                                            const FmqResultDescriptor& resultChannel,
277                                            IPreparedModel* preparedModel);
278 
279     ExecutionBurstServer(const sp<IBurstCallback>& callback,
280                          std::unique_ptr<RequestChannelReceiver> requestChannel,
281                          std::unique_ptr<ResultChannelSender> resultChannel,
282                          std::shared_ptr<IBurstExecutorWithCache> cachedExecutor);
283     ~ExecutionBurstServer();
284 
285     // Used by the NN runtime to preemptively remove any stored memory.
286     Return<void> freeMemory(int32_t slot) override;
287 
288    private:
289     // Ensures all cache entries contained in mExecutorWithCache are present in
290     // the cache. If they are not present, they are retrieved (via
291     // IBurstCallback::getMemories) and added to mExecutorWithCache.
292     //
293     // This method is locked via mMutex when it is called.
294     void ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots);
295 
296     // Work loop that will continue processing execution requests until the
297     // ExecutionBurstServer object is freed.
298     void task();
299 
300     std::thread mWorker;
301     std::mutex mMutex;
302     std::atomic<bool> mTeardown{false};
303     const sp<IBurstCallback> mCallback;
304     const std::unique_ptr<RequestChannelReceiver> mRequestChannelReceiver;
305     const std::unique_ptr<ResultChannelSender> mResultChannelSender;
306     const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache;
307 };
308 
309 }  // namespace android::nn
310 
311 #endif  // ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
312