• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include "VtsHalNeuralnetworks.h"
20 
21 #include "Callbacks.h"
22 #include "ExecutionBurstController.h"
23 #include "ExecutionBurstServer.h"
24 #include "TestHarness.h"
25 #include "Utils.h"
26 
27 #include <android-base/logging.h>
28 #include <cstring>
29 
30 namespace android {
31 namespace hardware {
32 namespace neuralnetworks {
33 namespace V1_2 {
34 namespace vts {
35 namespace functional {
36 
37 using ::android::nn::ExecutionBurstController;
38 using ::android::nn::RequestChannelSender;
39 using ::android::nn::ResultChannelReceiver;
40 using ExecutionBurstCallback = ::android::nn::ExecutionBurstController::ExecutionBurstCallback;
41 
42 // This constant value represents the length of an FMQ that is large enough to
43 // return a result from a burst execution for all of the generated test cases.
44 constexpr size_t kExecutionBurstChannelLength = 1024;
45 
46 // This constant value represents a length of an FMQ that is not large enough
47 // to return a result from a burst execution for some of the generated test
48 // cases.
49 constexpr size_t kExecutionBurstChannelSmallLength = 8;
50 
51 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
52 
badTiming(Timing timing)53 static bool badTiming(Timing timing) {
54     return timing.timeOnDevice == UINT64_MAX && timing.timeInDriver == UINT64_MAX;
55 }
56 
createBurst(const sp<IPreparedModel> & preparedModel,const sp<IBurstCallback> & callback,std::unique_ptr<RequestChannelSender> * sender,std::unique_ptr<ResultChannelReceiver> * receiver,sp<IBurstContext> * context,size_t resultChannelLength=kExecutionBurstChannelLength)57 static void createBurst(const sp<IPreparedModel>& preparedModel, const sp<IBurstCallback>& callback,
58                         std::unique_ptr<RequestChannelSender>* sender,
59                         std::unique_ptr<ResultChannelReceiver>* receiver,
60                         sp<IBurstContext>* context,
61                         size_t resultChannelLength = kExecutionBurstChannelLength) {
62     ASSERT_NE(nullptr, preparedModel.get());
63     ASSERT_NE(nullptr, sender);
64     ASSERT_NE(nullptr, receiver);
65     ASSERT_NE(nullptr, context);
66 
67     // create FMQ objects
68     auto [fmqRequestChannel, fmqRequestDescriptor] =
69             RequestChannelSender::create(kExecutionBurstChannelLength, /*blocking=*/true);
70     auto [fmqResultChannel, fmqResultDescriptor] =
71             ResultChannelReceiver::create(resultChannelLength, /*blocking=*/true);
72     ASSERT_NE(nullptr, fmqRequestChannel.get());
73     ASSERT_NE(nullptr, fmqResultChannel.get());
74     ASSERT_NE(nullptr, fmqRequestDescriptor);
75     ASSERT_NE(nullptr, fmqResultDescriptor);
76 
77     // configure burst
78     ErrorStatus errorStatus;
79     sp<IBurstContext> burstContext;
80     const Return<void> ret = preparedModel->configureExecutionBurst(
81             callback, *fmqRequestDescriptor, *fmqResultDescriptor,
82             [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
83                 errorStatus = status;
84                 burstContext = context;
85             });
86     ASSERT_TRUE(ret.isOk());
87     ASSERT_EQ(ErrorStatus::NONE, errorStatus);
88     ASSERT_NE(nullptr, burstContext.get());
89 
90     // return values
91     *sender = std::move(fmqRequestChannel);
92     *receiver = std::move(fmqResultChannel);
93     *context = burstContext;
94 }
95 
createBurstWithResultChannelLength(const sp<IPreparedModel> & preparedModel,size_t resultChannelLength,std::shared_ptr<ExecutionBurstController> * controller)96 static void createBurstWithResultChannelLength(
97         const sp<IPreparedModel>& preparedModel, size_t resultChannelLength,
98         std::shared_ptr<ExecutionBurstController>* controller) {
99     ASSERT_NE(nullptr, preparedModel.get());
100     ASSERT_NE(nullptr, controller);
101 
102     // create FMQ objects
103     std::unique_ptr<RequestChannelSender> sender;
104     std::unique_ptr<ResultChannelReceiver> receiver;
105     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
106     sp<IBurstContext> context;
107     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context,
108                                         resultChannelLength));
109     ASSERT_NE(nullptr, sender.get());
110     ASSERT_NE(nullptr, receiver.get());
111     ASSERT_NE(nullptr, context.get());
112 
113     // return values
114     *controller = std::make_shared<ExecutionBurstController>(std::move(sender), std::move(receiver),
115                                                              context, callback);
116 }
117 
118 // Primary validation function. This function will take a valid serialized
119 // request, apply a mutation to it to invalidate the serialized request, then
120 // pass it to interface calls that use the serialized request. Note that the
121 // serialized request here is passed by value, and any mutation to the
122 // serialized request does not leave this function.
validate(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::string & message,std::vector<FmqRequestDatum> serialized,const std::function<void (std::vector<FmqRequestDatum> *)> & mutation)123 static void validate(RequestChannelSender* sender, ResultChannelReceiver* receiver,
124                      const std::string& message, std::vector<FmqRequestDatum> serialized,
125                      const std::function<void(std::vector<FmqRequestDatum>*)>& mutation) {
126     mutation(&serialized);
127 
128     // skip if packet is too large to send
129     if (serialized.size() > kExecutionBurstChannelLength) {
130         return;
131     }
132 
133     SCOPED_TRACE(message);
134 
135     // send invalid packet
136     ASSERT_TRUE(sender->sendPacket(serialized));
137 
138     // receive error
139     auto results = receiver->getBlocking();
140     ASSERT_TRUE(results.has_value());
141     const auto [status, outputShapes, timing] = std::move(*results);
142     EXPECT_NE(ErrorStatus::NONE, status);
143     EXPECT_EQ(0u, outputShapes.size());
144     EXPECT_TRUE(badTiming(timing));
145 }
146 
147 // For validation, valid packet entries are mutated to invalid packet entries,
148 // or invalid packet entries are inserted into valid packets. This function
149 // creates pre-set invalid packet entries for convenience.
createBadRequestPacketEntries()150 static std::vector<FmqRequestDatum> createBadRequestPacketEntries() {
151     const FmqRequestDatum::PacketInformation packetInformation = {
152             /*.packetSize=*/10, /*.numberOfInputOperands=*/10, /*.numberOfOutputOperands=*/10,
153             /*.numberOfPools=*/10};
154     const FmqRequestDatum::OperandInformation operandInformation = {
155             /*.hasNoValue=*/false, /*.location=*/{}, /*.numberOfDimensions=*/10};
156     const int32_t invalidPoolIdentifier = std::numeric_limits<int32_t>::max();
157     std::vector<FmqRequestDatum> bad(7);
158     bad[0].packetInformation(packetInformation);
159     bad[1].inputOperandInformation(operandInformation);
160     bad[2].inputOperandDimensionValue(0);
161     bad[3].outputOperandInformation(operandInformation);
162     bad[4].outputOperandDimensionValue(0);
163     bad[5].poolIdentifier(invalidPoolIdentifier);
164     bad[6].measureTiming(MeasureTiming::YES);
165     return bad;
166 }
167 
168 // For validation, valid packet entries are mutated to invalid packet entries,
169 // or invalid packet entries are inserted into valid packets. This function
170 // retrieves pre-set invalid packet entries for convenience. This function
171 // caches these data so they can be reused on subsequent validation checks.
getBadRequestPacketEntries()172 static const std::vector<FmqRequestDatum>& getBadRequestPacketEntries() {
173     static const std::vector<FmqRequestDatum> bad = createBadRequestPacketEntries();
174     return bad;
175 }
176 
177 ///////////////////////// REMOVE DATUM ////////////////////////////////////
178 
removeDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)179 static void removeDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
180                             const std::vector<FmqRequestDatum>& serialized) {
181     for (size_t index = 0; index < serialized.size(); ++index) {
182         const std::string message = "removeDatum: removed datum at index " + std::to_string(index);
183         validate(sender, receiver, message, serialized,
184                  [index](std::vector<FmqRequestDatum>* serialized) {
185                      serialized->erase(serialized->begin() + index);
186                  });
187     }
188 }
189 
190 ///////////////////////// ADD DATUM ////////////////////////////////////
191 
addDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)192 static void addDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
193                          const std::vector<FmqRequestDatum>& serialized) {
194     const std::vector<FmqRequestDatum>& extra = getBadRequestPacketEntries();
195     for (size_t index = 0; index <= serialized.size(); ++index) {
196         for (size_t type = 0; type < extra.size(); ++type) {
197             const std::string message = "addDatum: added datum type " + std::to_string(type) +
198                                         " at index " + std::to_string(index);
199             validate(sender, receiver, message, serialized,
200                      [index, type, &extra](std::vector<FmqRequestDatum>* serialized) {
201                          serialized->insert(serialized->begin() + index, extra[type]);
202                      });
203         }
204     }
205 }
206 
207 ///////////////////////// MUTATE DATUM ////////////////////////////////////
208 
interestingCase(const FmqRequestDatum & lhs,const FmqRequestDatum & rhs)209 static bool interestingCase(const FmqRequestDatum& lhs, const FmqRequestDatum& rhs) {
210     using Discriminator = FmqRequestDatum::hidl_discriminator;
211 
212     const bool differentValues = (lhs != rhs);
213     const bool sameDiscriminator = (lhs.getDiscriminator() == rhs.getDiscriminator());
214     const auto discriminator = rhs.getDiscriminator();
215     const bool isDimensionValue = (discriminator == Discriminator::inputOperandDimensionValue ||
216                                    discriminator == Discriminator::outputOperandDimensionValue);
217 
218     return differentValues && !(sameDiscriminator && isDimensionValue);
219 }
220 
mutateDatumTest(RequestChannelSender * sender,ResultChannelReceiver * receiver,const std::vector<FmqRequestDatum> & serialized)221 static void mutateDatumTest(RequestChannelSender* sender, ResultChannelReceiver* receiver,
222                             const std::vector<FmqRequestDatum>& serialized) {
223     const std::vector<FmqRequestDatum>& change = getBadRequestPacketEntries();
224     for (size_t index = 0; index < serialized.size(); ++index) {
225         for (size_t type = 0; type < change.size(); ++type) {
226             if (interestingCase(serialized[index], change[type])) {
227                 const std::string message = "mutateDatum: changed datum at index " +
228                                             std::to_string(index) + " to datum type " +
229                                             std::to_string(type);
230                 validate(sender, receiver, message, serialized,
231                          [index, type, &change](std::vector<FmqRequestDatum>* serialized) {
232                              (*serialized)[index] = change[type];
233                          });
234             }
235         }
236     }
237 }
238 
///////////////////////// BURST VALIDATION TESTS ////////////////////////////////////
240 
validateBurstSerialization(const sp<IPreparedModel> & preparedModel,const std::vector<Request> & requests)241 static void validateBurstSerialization(const sp<IPreparedModel>& preparedModel,
242                                        const std::vector<Request>& requests) {
243     // create burst
244     std::unique_ptr<RequestChannelSender> sender;
245     std::unique_ptr<ResultChannelReceiver> receiver;
246     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
247     sp<IBurstContext> context;
248     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
249     ASSERT_NE(nullptr, sender.get());
250     ASSERT_NE(nullptr, receiver.get());
251     ASSERT_NE(nullptr, context.get());
252 
253     // validate each request
254     for (const Request& request : requests) {
255         // load memory into callback slots
256         std::vector<intptr_t> keys;
257         keys.reserve(request.pools.size());
258         std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
259                        [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
260         const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
261 
262         // ensure slot std::numeric_limits<int32_t>::max() doesn't exist (for
263         // subsequent slot validation testing)
264         ASSERT_TRUE(std::all_of(slots.begin(), slots.end(), [](int32_t slot) {
265             return slot != std::numeric_limits<int32_t>::max();
266         }));
267 
268         // serialize the request
269         const auto serialized = ::android::nn::serialize(request, MeasureTiming::YES, slots);
270 
271         // validations
272         removeDatumTest(sender.get(), receiver.get(), serialized);
273         addDatumTest(sender.get(), receiver.get(), serialized);
274         mutateDatumTest(sender.get(), receiver.get(), serialized);
275     }
276 }
277 
278 // This test validates that when the Result message size exceeds length of the
279 // result FMQ, the service instance gracefully fails and returns an error.
validateBurstFmqLength(const sp<IPreparedModel> & preparedModel,const std::vector<Request> & requests)280 static void validateBurstFmqLength(const sp<IPreparedModel>& preparedModel,
281                                    const std::vector<Request>& requests) {
282     // create regular burst
283     std::shared_ptr<ExecutionBurstController> controllerRegular;
284     ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
285             preparedModel, kExecutionBurstChannelLength, &controllerRegular));
286     ASSERT_NE(nullptr, controllerRegular.get());
287 
288     // create burst with small output channel
289     std::shared_ptr<ExecutionBurstController> controllerSmall;
290     ASSERT_NO_FATAL_FAILURE(createBurstWithResultChannelLength(
291             preparedModel, kExecutionBurstChannelSmallLength, &controllerSmall));
292     ASSERT_NE(nullptr, controllerSmall.get());
293 
294     // validate each request
295     for (const Request& request : requests) {
296         // load memory into callback slots
297         std::vector<intptr_t> keys(request.pools.size());
298         for (size_t i = 0; i < keys.size(); ++i) {
299             keys[i] = reinterpret_cast<intptr_t>(&request.pools[i]);
300         }
301 
302         // collect serialized result by running regular burst
303         const auto [statusRegular, outputShapesRegular, timingRegular] =
304                 controllerRegular->compute(request, MeasureTiming::NO, keys);
305 
306         // skip test if regular burst output isn't useful for testing a failure
307         // caused by having too small of a length for the result FMQ
308         const std::vector<FmqResultDatum> serialized =
309                 ::android::nn::serialize(statusRegular, outputShapesRegular, timingRegular);
310         if (statusRegular != ErrorStatus::NONE ||
311             serialized.size() <= kExecutionBurstChannelSmallLength) {
312             continue;
313         }
314 
315         // by this point, execution should fail because the result channel isn't
316         // large enough to return the serialized result
317         const auto [statusSmall, outputShapesSmall, timingSmall] =
318                 controllerSmall->compute(request, MeasureTiming::NO, keys);
319         EXPECT_NE(ErrorStatus::NONE, statusSmall);
320         EXPECT_EQ(0u, outputShapesSmall.size());
321         EXPECT_TRUE(badTiming(timingSmall));
322     }
323 }
324 
isSanitized(const FmqResultDatum & datum)325 static bool isSanitized(const FmqResultDatum& datum) {
326     using Discriminator = FmqResultDatum::hidl_discriminator;
327 
328     // check to ensure the padding values in the returned
329     // FmqResultDatum::OperandInformation are initialized to 0
330     if (datum.getDiscriminator() == Discriminator::operandInformation) {
331         static_assert(
332                 offsetof(FmqResultDatum::OperandInformation, isSufficient) == 0,
333                 "unexpected value for offset of FmqResultDatum::OperandInformation::isSufficient");
334         static_assert(
335                 sizeof(FmqResultDatum::OperandInformation::isSufficient) == 1,
336                 "unexpected value for size of FmqResultDatum::OperandInformation::isSufficient");
337         static_assert(offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) == 4,
338                       "unexpected value for offset of "
339                       "FmqResultDatum::OperandInformation::numberOfDimensions");
340         static_assert(sizeof(FmqResultDatum::OperandInformation::numberOfDimensions) == 4,
341                       "unexpected value for size of "
342                       "FmqResultDatum::OperandInformation::numberOfDimensions");
343         static_assert(sizeof(FmqResultDatum::OperandInformation) == 8,
344                       "unexpected value for size of "
345                       "FmqResultDatum::OperandInformation");
346 
347         constexpr size_t paddingOffset =
348                 offsetof(FmqResultDatum::OperandInformation, isSufficient) +
349                 sizeof(FmqResultDatum::OperandInformation::isSufficient);
350         constexpr size_t paddingSize =
351                 offsetof(FmqResultDatum::OperandInformation, numberOfDimensions) - paddingOffset;
352 
353         FmqResultDatum::OperandInformation initialized{};
354         std::memset(&initialized, 0, sizeof(initialized));
355 
356         const char* initializedPaddingStart =
357                 reinterpret_cast<const char*>(&initialized) + paddingOffset;
358         const char* datumPaddingStart =
359                 reinterpret_cast<const char*>(&datum.operandInformation()) + paddingOffset;
360 
361         return std::memcmp(datumPaddingStart, initializedPaddingStart, paddingSize) == 0;
362     }
363 
364     // there are no other padding initialization checks required, so return true
365     // for any sum-type that isn't FmqResultDatum::OperandInformation
366     return true;
367 }
368 
validateBurstSanitized(const sp<IPreparedModel> & preparedModel,const std::vector<Request> & requests)369 static void validateBurstSanitized(const sp<IPreparedModel>& preparedModel,
370                                    const std::vector<Request>& requests) {
371     // create burst
372     std::unique_ptr<RequestChannelSender> sender;
373     std::unique_ptr<ResultChannelReceiver> receiver;
374     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
375     sp<IBurstContext> context;
376     ASSERT_NO_FATAL_FAILURE(createBurst(preparedModel, callback, &sender, &receiver, &context));
377     ASSERT_NE(nullptr, sender.get());
378     ASSERT_NE(nullptr, receiver.get());
379     ASSERT_NE(nullptr, context.get());
380 
381     // validate each request
382     for (const Request& request : requests) {
383         // load memory into callback slots
384         std::vector<intptr_t> keys;
385         keys.reserve(request.pools.size());
386         std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(keys),
387                        [](const auto& pool) { return reinterpret_cast<intptr_t>(&pool); });
388         const std::vector<int32_t> slots = callback->getSlots(request.pools, keys);
389 
390         // send valid request
391         ASSERT_TRUE(sender->send(request, MeasureTiming::YES, slots));
392 
393         // receive valid result
394         auto serialized = receiver->getPacketBlocking();
395         ASSERT_TRUE(serialized.has_value());
396 
397         // sanitize result
398         ASSERT_TRUE(std::all_of(serialized->begin(), serialized->end(), isSanitized))
399                 << "The result serialized data is not properly sanitized";
400     }
401 }
402 
403 ///////////////////////////// ENTRY POINT //////////////////////////////////
404 
validateBurst(const sp<IPreparedModel> & preparedModel,const std::vector<Request> & requests)405 void ValidationTest::validateBurst(const sp<IPreparedModel>& preparedModel,
406                                    const std::vector<Request>& requests) {
407     ASSERT_NO_FATAL_FAILURE(validateBurstSerialization(preparedModel, requests));
408     ASSERT_NO_FATAL_FAILURE(validateBurstFmqLength(preparedModel, requests));
409     ASSERT_NO_FATAL_FAILURE(validateBurstSanitized(preparedModel, requests));
410 }
411 
412 }  // namespace functional
413 }  // namespace vts
414 }  // namespace V1_2
415 }  // namespace neuralnetworks
416 }  // namespace hardware
417 }  // namespace android
418