• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ValidateHal"
18 
19 #include "ValidateHal.h"
20 #include "NeuralNetworks.h"
21 #include "Utils.h"
22 
23 #include <android-base/logging.h>
24 
25 namespace android {
26 namespace nn {
27 
28 class MemoryAccessVerifier {
29 public:
MemoryAccessVerifier(const hidl_vec<hidl_memory> & pools)30     MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
31         : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
32         for (size_t i = 0; i < mPoolCount; i++) {
33             mPoolSizes[i] = pools[i].size();
34         }
35     }
validate(const DataLocation & location)36     bool validate(const DataLocation& location) {
37         if (location.poolIndex >= mPoolCount) {
38             LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
39             return false;
40         }
41         const size_t size = mPoolSizes[location.poolIndex];
42         // Do the addition using size_t to avoid potential wrap-around problems.
43         if (static_cast<size_t>(location.offset) + location.length > size) {
44             LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
45                        << location.offset << " and length " << location.length
46                        << " exceeds pool size of " << size;
47             return false;
48         }
49         return true;
50     }
51 
52 private:
53     size_t mPoolCount;
54     std::vector<size_t> mPoolSizes;
55 };
56 
validateOperands(const hidl_vec<Operand> & operands,const hidl_vec<uint8_t> & operandValues,const hidl_vec<hidl_memory> & pools)57 static bool validateOperands(const hidl_vec<Operand>& operands,
58                              const hidl_vec<uint8_t>& operandValues,
59                              const hidl_vec<hidl_memory>& pools) {
60     uint32_t index = 0;
61     MemoryAccessVerifier poolVerifier(pools);
62     for (auto& operand : operands) {
63         // Validate type and dimensions.
64         switch (operand.type) {
65             case OperandType::FLOAT32:
66             case OperandType::INT32:
67             case OperandType::UINT32:
68             case OperandType::OEM: {
69                 size_t count = operand.dimensions.size();
70                 if (count != 0) {
71                     LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
72                                << count;
73                     return false;
74                 }
75                 break;
76             }
77             case OperandType::TENSOR_FLOAT32:
78             case OperandType::TENSOR_INT32:
79             case OperandType::TENSOR_QUANT8_ASYMM:
80             case OperandType::TENSOR_OEM_BYTE: {
81                 if (operand.dimensions.size() == 0) {
82                     LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
83                     return false;
84                 }
85                 break;
86             }
87             default:
88                 LOG(ERROR) << "Operand " << index << ": Invalid operand type "
89                            << toString(operand.type);
90                 return false;
91         }
92 
93         // TODO Validate the numberOfConsumers.
94         // TODO Since we have to validate it, there was no point in including it. For the next
95         // release, consider removing unless we have an additional process in system space
96         // that creates this value. In that case, it would not have to be validated.
97 
98         // Validate the scale.
99         switch (operand.type) {
100             case OperandType::FLOAT32:
101             case OperandType::INT32:
102             case OperandType::UINT32:
103             case OperandType::TENSOR_FLOAT32:
104                 if (operand.scale != 0.f) {
105                     LOG(ERROR) << "Operand " << index << ": Operand of type "
106                                << getOperandTypeName(operand.type) << " with a non-zero scale ("
107                                << operand.scale << ")";
108                     return false;
109                 }
110                 break;
111             case OperandType::TENSOR_INT32:
112                 // TENSOR_INT32 may be used with or without scale, depending on the operation.
113                 if (operand.scale < 0.f) {
114                     LOG(ERROR) << "Operand " << index << ": Operand of type "
115                                << getOperandTypeName(operand.type) << " with a negative scale";
116                     return false;
117                 }
118                 break;
119             case OperandType::TENSOR_QUANT8_ASYMM:
120                 if (operand.scale <= 0.f) {
121                     LOG(ERROR) << "Operand " << index << ": Operand of type "
122                                << getOperandTypeName(operand.type) << " with a non-positive scale";
123                     return false;
124                 }
125                 break;
126             default:
127                 // No validation for the OEM types.
128                 // TODO We should have had a separate type for TENSOR_INT32 that a scale
129                 // and those who don't.  Document now and fix in the next release.
130                 break;
131         }
132 
133         // Validate the zeroPoint.
134         switch (operand.type) {
135             case OperandType::FLOAT32:
136             case OperandType::INT32:
137             case OperandType::UINT32:
138             case OperandType::TENSOR_FLOAT32:
139             case OperandType::TENSOR_INT32:
140                 if (operand.zeroPoint != 0) {
141                     LOG(ERROR) << "Operand " << index << ": Operand of type "
142                                << getOperandTypeName(operand.type) << " with an non-zero zeroPoint "
143                                << operand.zeroPoint;
144                     return false;
145                 }
146                 break;
147             case OperandType::TENSOR_QUANT8_ASYMM:
148                 if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
149                     LOG(ERROR) << "Operand " << index << ": Operand of type "
150                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
151                                << operand.zeroPoint << ", must be in range [0, 255]";
152                     return false;
153                 }
154                 break;
155             default:
156                 // No validation for the OEM types.
157                 break;
158         }
159 
160         // Validate the lifetime and the location.
161         const DataLocation& location = operand.location;
162         switch (operand.lifetime) {
163             case OperandLifeTime::CONSTANT_COPY:
164                 if (location.poolIndex != 0) {
165                     LOG(ERROR) << "Operand " << index
166                                << ": CONSTANT_COPY with a non-zero poolIndex "
167                                << location.poolIndex;
168                     return false;
169                 }
170                 // Do the addition using size_t to avoid potential wrap-around problems.
171                 if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) {
172                     LOG(ERROR) << "Operand " << index
173                                << ": OperandValue location out of range.  Starts at "
174                                << location.offset << ", length " << location.length << ", max "
175                                << operandValues.size();
176                     return false;
177                 }
178                 break;
179             case OperandLifeTime::CONSTANT_REFERENCE:
180                 if (!poolVerifier.validate(location)) {
181                     return false;
182                 }
183                 break;
184             case OperandLifeTime::TEMPORARY_VARIABLE:
185             case OperandLifeTime::MODEL_INPUT:
186             case OperandLifeTime::MODEL_OUTPUT:
187             case OperandLifeTime::NO_VALUE:
188                 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
189                     LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
190                                << location.poolIndex << ", offset " << location.offset
191                                << ", or length " << location.length << " for operand of lifetime "
192                                << toString(operand.lifetime);
193                     return false;
194                 }
195                 break;
196             default:
197                 LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
198                            << toString(operand.lifetime);
199                 return false;
200         }
201 
202         // For constants, validate that the length is as expected. The other lifetimes
203         // expect the length to be 0. Don't validate for OEM types.
204         if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
205             operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
206             if (operand.type != OperandType::OEM &&
207                 operand.type != OperandType::TENSOR_OEM_BYTE) {
208                 uint32_t expectedLength = sizeOfData(operand.type, operand.dimensions);
209                 if (location.length != expectedLength) {
210                     LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
211                                << " expected a size of " << expectedLength << " but got "
212                                << location.length;
213                     return false;
214                 }
215             }
216         }
217 
218         index++;
219     }
220     return true;
221 }
222 
validOperationType(V1_0::OperationType operation)223 static bool validOperationType(V1_0::OperationType operation) {
224     switch (operation) {
225         case V1_0::OperationType::ADD:
226         case V1_0::OperationType::AVERAGE_POOL_2D:
227         case V1_0::OperationType::CONCATENATION:
228         case V1_0::OperationType::CONV_2D:
229         case V1_0::OperationType::DEPTHWISE_CONV_2D:
230         case V1_0::OperationType::DEPTH_TO_SPACE:
231         case V1_0::OperationType::DEQUANTIZE:
232         case V1_0::OperationType::EMBEDDING_LOOKUP:
233         case V1_0::OperationType::FLOOR:
234         case V1_0::OperationType::FULLY_CONNECTED:
235         case V1_0::OperationType::HASHTABLE_LOOKUP:
236         case V1_0::OperationType::L2_NORMALIZATION:
237         case V1_0::OperationType::L2_POOL_2D:
238         case V1_0::OperationType::LOCAL_RESPONSE_NORMALIZATION:
239         case V1_0::OperationType::LOGISTIC:
240         case V1_0::OperationType::LSH_PROJECTION:
241         case V1_0::OperationType::LSTM:
242         case V1_0::OperationType::MAX_POOL_2D:
243         case V1_0::OperationType::MUL:
244         case V1_0::OperationType::RELU:
245         case V1_0::OperationType::RELU1:
246         case V1_0::OperationType::RELU6:
247         case V1_0::OperationType::RESHAPE:
248         case V1_0::OperationType::RESIZE_BILINEAR:
249         case V1_0::OperationType::RNN:
250         case V1_0::OperationType::SOFTMAX:
251         case V1_0::OperationType::SPACE_TO_DEPTH:
252         case V1_0::OperationType::SVDF:
253         case V1_0::OperationType::TANH:
254         case V1_0::OperationType::OEM_OPERATION:
255             return true;
256         default:
257             return false;
258     }
259 }
260 
validOperationType(V1_1::OperationType operation)261 static bool validOperationType(V1_1::OperationType operation) {
262     switch (operation) {
263         case V1_1::OperationType::ADD:
264         case V1_1::OperationType::AVERAGE_POOL_2D:
265         case V1_1::OperationType::CONCATENATION:
266         case V1_1::OperationType::CONV_2D:
267         case V1_1::OperationType::DEPTHWISE_CONV_2D:
268         case V1_1::OperationType::DEPTH_TO_SPACE:
269         case V1_1::OperationType::DEQUANTIZE:
270         case V1_1::OperationType::EMBEDDING_LOOKUP:
271         case V1_1::OperationType::FLOOR:
272         case V1_1::OperationType::FULLY_CONNECTED:
273         case V1_1::OperationType::HASHTABLE_LOOKUP:
274         case V1_1::OperationType::L2_NORMALIZATION:
275         case V1_1::OperationType::L2_POOL_2D:
276         case V1_1::OperationType::LOCAL_RESPONSE_NORMALIZATION:
277         case V1_1::OperationType::LOGISTIC:
278         case V1_1::OperationType::LSH_PROJECTION:
279         case V1_1::OperationType::LSTM:
280         case V1_1::OperationType::MAX_POOL_2D:
281         case V1_1::OperationType::MUL:
282         case V1_1::OperationType::RELU:
283         case V1_1::OperationType::RELU1:
284         case V1_1::OperationType::RELU6:
285         case V1_1::OperationType::RESHAPE:
286         case V1_1::OperationType::RESIZE_BILINEAR:
287         case V1_1::OperationType::RNN:
288         case V1_1::OperationType::SOFTMAX:
289         case V1_1::OperationType::SPACE_TO_DEPTH:
290         case V1_1::OperationType::SVDF:
291         case V1_1::OperationType::TANH:
292         case V1_1::OperationType::BATCH_TO_SPACE_ND:
293         case V1_1::OperationType::DIV:
294         case V1_1::OperationType::MEAN:
295         case V1_1::OperationType::PAD:
296         case V1_1::OperationType::SPACE_TO_BATCH_ND:
297         case V1_1::OperationType::SQUEEZE:
298         case V1_1::OperationType::STRIDED_SLICE:
299         case V1_1::OperationType::SUB:
300         case V1_1::OperationType::TRANSPOSE:
301         case V1_1::OperationType::OEM_OPERATION:
302             return true;
303         default:
304             return false;
305     }
306 }
307 
308 template<typename VersionedOperation>
validateOperations(const hidl_vec<VersionedOperation> & operations,const hidl_vec<Operand> & operands)309 static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
310                                const hidl_vec<Operand>& operands) {
311     const size_t operandCount = operands.size();
312     // This vector keeps track of whether there's an operation that writes to
313     // each operand. It is used to validate that temporary variables and
314     // model outputs will be written to.
315     std::vector<bool> writtenTo(operandCount, false);
316     for (auto& op : operations) {
317         if (!validOperationType(op.type)) {
318             LOG(ERROR) << "Invalid operation type " << toString(op.type);
319             return false;
320         }
321         // TODO Validate the shapes and any known values. This is currently
322         // done in CpuExecutor but should be done here for all drivers.
323         int error =
324             validateOperation(static_cast<int32_t>(op.type), op.inputs.size(),
325                               op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
326                               op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands);
327         if (error != ANEURALNETWORKS_NO_ERROR) {
328             return false;
329         }
330 
331         for (uint32_t i : op.outputs) {
332             const Operand& operand = operands[i];
333             if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE &&
334                 operand.lifetime != OperandLifeTime::MODEL_OUTPUT) {
335                 LOG(ERROR) << "Writing to an operand with incompatible lifetime "
336                            << toString(operand.lifetime);
337                 return false;
338             }
339 
340             // Check that we only write once to an operand.
341             if (writtenTo[i]) {
342                 LOG(ERROR) << "Operand " << i << " written a second time";
343                 return false;
344             }
345             writtenTo[i] = true;
346         }
347     }
348     for (size_t i = 0; i < operandCount; i++) {
349         if (!writtenTo[i]) {
350             const Operand& operand = operands[i];
351             if (operand.lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
352                 operand.lifetime == OperandLifeTime::MODEL_OUTPUT) {
353                 LOG(ERROR) << "Operand " << i << " with lifetime " << toString(operand.lifetime)
354                            << " is not being written to.";
355                 return false;
356             }
357         }
358     }
359     // TODO More whole graph verifications are possible, for example that an
360     // operand is not use as input & output for the same op, and more
361     // generally that it is acyclic.
362     return true;
363 }
364 
validatePools(const hidl_vec<hidl_memory> & pools)365 static bool validatePools(const hidl_vec<hidl_memory>& pools) {
366     for (const hidl_memory& memory : pools) {
367         const auto name = memory.name();
368         if (name != "ashmem" && name != "mmap_fd") {
369             LOG(ERROR) << "Unsupported memory type " << name;
370             return false;
371         }
372         if (memory.handle() == nullptr) {
373             LOG(ERROR) << "Memory of type " << name << " is null";
374             return false;
375         }
376     }
377     return true;
378 }
379 
validateModelInputOutputs(const hidl_vec<uint32_t> indexes,const hidl_vec<Operand> & operands,OperandLifeTime lifetime)380 static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes,
381                                       const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
382     const size_t operandCount = operands.size();
383     for (uint32_t i : indexes) {
384         if (i >= operandCount) {
385             LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
386             return false;
387         }
388         const Operand& operand = operands[i];
389         if (operand.lifetime != lifetime) {
390             LOG(ERROR) << "Model input or output has lifetime of " << toString(operand.lifetime)
391                        << " instead of the expected " << toString(lifetime);
392             return false;
393         }
394     }
395 
396     std::vector<uint32_t> sortedIndexes = indexes;
397     std::sort(sortedIndexes.begin(), sortedIndexes.end());
398     auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
399     if (adjacentI != sortedIndexes.end()) {
400         LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
401         return false;
402     }
403     return true;
404 }
405 
406 template<typename VersionedModel>
validateModelVersioned(const VersionedModel & model)407 static bool validateModelVersioned(const VersionedModel& model) {
408     return (validateOperands(model.operands, model.operandValues, model.pools) &&
409             validateOperations(model.operations, model.operands) &&
410             validateModelInputOutputs(model.inputIndexes, model.operands,
411                                       OperandLifeTime::MODEL_INPUT) &&
412             validateModelInputOutputs(model.outputIndexes, model.operands,
413                                       OperandLifeTime::MODEL_OUTPUT) &&
414             validatePools(model.pools));
415 }
416 
validateModel(const V1_0::Model & model)417 bool validateModel(const V1_0::Model& model) {
418     return validateModelVersioned(model);
419 }
420 
validateModel(const V1_1::Model & model)421 bool validateModel(const V1_1::Model& model) {
422     return validateModelVersioned(model);
423 }
424 
425 // Validates the arguments of a request. type is either "input" or "output" and is used
426 // for printing error messages. The operandIndexes is the appropriate array of input
427 // or output operand indexes that was passed to the ANeuralNetworksModel_identifyInputsAndOutputs.
validateRequestArguments(const hidl_vec<RequestArgument> & requestArguments,const hidl_vec<uint32_t> & operandIndexes,const hidl_vec<Operand> & operands,const hidl_vec<hidl_memory> & pools,const char * type)428 static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
429                                      const hidl_vec<uint32_t>& operandIndexes,
430                                      const hidl_vec<Operand>& operands,
431                                      const hidl_vec<hidl_memory>& pools, const char* type) {
432     MemoryAccessVerifier poolVerifier(pools);
433     // The request should specify as many arguments as were described in the model.
434     const size_t requestArgumentCount = requestArguments.size();
435     if (requestArgumentCount != operandIndexes.size()) {
436         LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
437                    << "s but the model has " << operandIndexes.size();
438         return false;
439     }
440     for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
441          requestArgumentIndex++) {
442         const RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
443         const DataLocation& location = requestArgument.location;
444         // Get the operand index for this argument. We extract it from the list
445         // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
446         // We assume in this function that the model has been validated already.
447         const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
448         const Operand& operand = operands[operandIndex];
449         if (requestArgument.hasNoValue) {
450             if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
451                 requestArgument.dimensions.size() != 0) {
452                 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
453                            << " has no value yet has details.";
454                 return false;
455             }
456         } else {
457             // Validate the location.
458             if (!poolVerifier.validate(location)) {
459                 return false;
460             }
461             // If the argument specified a dimension, validate it.
462             uint32_t rank = requestArgument.dimensions.size();
463             if (rank == 0) {
464                 // Validate that all the dimensions are specified in the model.
465                 for (size_t i = 0; i < operand.dimensions.size(); i++) {
466                     if (operand.dimensions[i] == 0) {
467                         LOG(ERROR) << "Model has dimension " << i
468                                    << " set to 0 but the request does specify the dimension.";
469                         return false;
470                     }
471                 }
472             } else {
473                 if (rank != operand.dimensions.size()) {
474                     LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
475                                << " has number of dimensions (" << rank
476                                << ") different than the model's (" << operand.dimensions.size()
477                                << ")";
478                     return false;
479                 }
480                 for (size_t i = 0; i < rank; i++) {
481                     if (requestArgument.dimensions[i] != operand.dimensions[i] &&
482                         operand.dimensions[i] != 0) {
483                         LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
484                                    << " has dimension " << i << " of "
485                                    << requestArgument.dimensions[i]
486                                    << " different than the model's " << operand.dimensions[i];
487                         return false;
488                     }
489                     if (requestArgument.dimensions[i] == 0) {
490                         LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
491                                    << " has dimension " << i << " of zero";
492                         return false;
493                     }
494                 }
495             }
496         }
497     }
498     return true;
499 }
500 
501 template<typename VersionedModel>
validateRequestVersioned(const Request & request,const VersionedModel & model)502 static bool validateRequestVersioned(const Request& request, const VersionedModel& model) {
503     return (validateRequestArguments(request.inputs, model.inputIndexes, model.operands,
504                                      request.pools, "input") &&
505             validateRequestArguments(request.outputs, model.outputIndexes, model.operands,
506                                      request.pools, "output") &&
507             validatePools(request.pools));
508 }
509 
validateRequest(const Request & request,const V1_0::Model & model)510 bool validateRequest(const Request& request, const V1_0::Model& model) {
511     return validateRequestVersioned(request, model);
512 }
513 
validateRequest(const Request & request,const V1_1::Model & model)514 bool validateRequest(const Request& request, const V1_1::Model& model) {
515     return validateRequestVersioned(request, model);
516 }
517 
validateExecutionPreference(ExecutionPreference preference)518 bool validateExecutionPreference(ExecutionPreference preference) {
519     return preference == ExecutionPreference::LOW_POWER ||
520            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
521            preference == ExecutionPreference::SUSTAINED_SPEED;
522 }
523 
524 }  // namespace nn
525 }  // namespace android
526