/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "ValidateHal"

#include "ValidateHal.h"
#include "NeuralNetworks.h"
#include "Utils.h"

#include <android-base/logging.h>

#include <algorithm>  // for std::sort and std::adjacent_find below
#include <vector>

namespace android {
namespace nn {
class MemoryAccessVerifier {
public:
    MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
        : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
        for (size_t i = 0; i < mPoolCount; i++) {
            mPoolSizes[i] = pools[i].size();
        }
    }
    bool validate(const DataLocation& location) {
        if (location.poolIndex >= mPoolCount) {
            LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
            return false;
        }
        const size_t size = mPoolSizes[location.poolIndex];
        // Do the addition using size_t to avoid potential wrap-around problems.
        if (static_cast<size_t>(location.offset) + location.length > size) {
            LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
                       << location.offset << " and length " << location.length
                       << " exceeds pool size of " << size;
            return false;
        }
        return true;
    }

private:
    size_t mPoolCount;
    std::vector<size_t> mPoolSizes;
};

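// Validates a model's operand list: each operand must have a known type,
// dimensions consistent with that type (scalars have rank 0, tensors have
// rank > 0), a scale and zeroPoint that make sense for the type, and a
// lifetime/location pair that stays within operandValues or the memory pools.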
static bool validateOperands(const hidl_vec<Operand>& operands,
                             const hidl_vec<uint8_t>& operandValues,
                             const hidl_vec<hidl_memory>& pools) {
    uint32_t index = 0;
    MemoryAccessVerifier poolVerifier(pools);
    for (auto& operand : operands) {
        // Validate type and dimensions.
        switch (operand.type) {
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::OEM: {
                size_t count = operand.dimensions.size();
                if (count != 0) {
                    LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
                               << count;
                    return false;
                }
                break;
            }
            case OperandType::TENSOR_FLOAT32:
            case OperandType::TENSOR_INT32:
            case OperandType::TENSOR_QUANT8_ASYMM:
            case OperandType::TENSOR_OEM_BYTE: {
                if (operand.dimensions.size() == 0) {
                    LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
                    return false;
                }
                break;
            }
            default:
                LOG(ERROR) << "Operand " << index << ": Invalid operand type "
                           << toString(operand.type);
                return false;
        }

        // TODO Validate the numberOfConsumers.
        // TODO Since we have to validate it, there was no point in including it. For the next
        // release, consider removing unless we have an additional process in system space
        // that creates this value. In that case, it would not have to be validated.

        // Validate the scale.
        switch (operand.type) {
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::TENSOR_FLOAT32:
                if (operand.scale != 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero scale ("
                               << operand.scale << ")";
                    return false;
                }
                break;
            case OperandType::TENSOR_INT32:
                // TENSOR_INT32 may be used with or without scale, depending on the operation.
                if (operand.scale < 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a negative scale";
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT8_ASYMM:
                if (operand.scale <= 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-positive scale";
                    return false;
                }
                break;
            default:
                // No validation for the OEM types.
                // TODO We should have had separate types for TENSOR_INT32 operands that
                // use a scale and those that don't. Document now and fix in the next release.
                break;
        }

        // Validate the zeroPoint.
        switch (operand.type) {
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::TENSOR_FLOAT32:
            case OperandType::TENSOR_INT32:
                if (operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT8_ASYMM:
                if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [0, 255]";
                    return false;
                }
                break;
            default:
                // No validation for the OEM types.
                break;
        }

        // Validate the lifetime and the location.
        const DataLocation& location = operand.location;
        switch (operand.lifetime) {
            case OperandLifeTime::CONSTANT_COPY:
                if (location.poolIndex != 0) {
                    LOG(ERROR) << "Operand " << index
                               << ": CONSTANT_COPY with a non-zero poolIndex "
                               << location.poolIndex;
                    return false;
                }
                // Do the addition using size_t to avoid potential wrap-around problems.
                if (static_cast<size_t>(location.offset) + location.length >
                    operandValues.size()) {
                    LOG(ERROR) << "Operand " << index
                               << ": OperandValue location out of range. Starts at "
                               << location.offset << ", length " << location.length << ", max "
                               << operandValues.size();
                    return false;
                }
                break;
            case OperandLifeTime::CONSTANT_REFERENCE:
                if (!poolVerifier.validate(location)) {
                    return false;
                }
                break;
            case OperandLifeTime::TEMPORARY_VARIABLE:
            case OperandLifeTime::MODEL_INPUT:
            case OperandLifeTime::MODEL_OUTPUT:
            case OperandLifeTime::NO_VALUE:
                if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
                    LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
                               << location.poolIndex << ", offset " << location.offset
                               << ", or length " << location.length << " for operand of lifetime "
                               << toString(operand.lifetime);
                    return false;
                }
                break;
            default:
                LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
                           << toString(operand.lifetime);
                return false;
        }

        // For constants, validate that the length is as expected. The other lifetimes
        // expect the length to be 0. Don't validate for OEM types.
        if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
            operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
            if (operand.type != OperandType::OEM &&
                operand.type != OperandType::TENSOR_OEM_BYTE) {
                uint32_t expectedLength = sizeOfData(operand.type, operand.dimensions);
                if (location.length != expectedLength) {
                    LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
                               << " expected a size of " << expectedLength << " but got "
                               << location.length;
                    return false;
                }
            }
        }

        index++;
    }
    return true;
}

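// Returns true if the operation type is one defined by the given HAL version
// (OEM_OPERATION included). Anything else is rejected by validateOperations.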
static bool validOperationType(V1_0::OperationType operation) {
    switch (operation) {
        case V1_0::OperationType::ADD:
        case V1_0::OperationType::AVERAGE_POOL_2D:
        case V1_0::OperationType::CONCATENATION:
        case V1_0::OperationType::CONV_2D:
        case V1_0::OperationType::DEPTHWISE_CONV_2D:
        case V1_0::OperationType::DEPTH_TO_SPACE:
        case V1_0::OperationType::DEQUANTIZE:
        case V1_0::OperationType::EMBEDDING_LOOKUP:
        case V1_0::OperationType::FLOOR:
        case V1_0::OperationType::FULLY_CONNECTED:
        case V1_0::OperationType::HASHTABLE_LOOKUP:
        case V1_0::OperationType::L2_NORMALIZATION:
        case V1_0::OperationType::L2_POOL_2D:
        case V1_0::OperationType::LOCAL_RESPONSE_NORMALIZATION:
        case V1_0::OperationType::LOGISTIC:
        case V1_0::OperationType::LSH_PROJECTION:
        case V1_0::OperationType::LSTM:
        case V1_0::OperationType::MAX_POOL_2D:
        case V1_0::OperationType::MUL:
        case V1_0::OperationType::RELU:
        case V1_0::OperationType::RELU1:
        case V1_0::OperationType::RELU6:
        case V1_0::OperationType::RESHAPE:
        case V1_0::OperationType::RESIZE_BILINEAR:
        case V1_0::OperationType::RNN:
        case V1_0::OperationType::SOFTMAX:
        case V1_0::OperationType::SPACE_TO_DEPTH:
        case V1_0::OperationType::SVDF:
        case V1_0::OperationType::TANH:
        case V1_0::OperationType::OEM_OPERATION:
            return true;
        default:
            return false;
    }
}

static bool validOperationType(V1_1::OperationType operation) {
    switch (operation) {
        case V1_1::OperationType::ADD:
        case V1_1::OperationType::AVERAGE_POOL_2D:
        case V1_1::OperationType::CONCATENATION:
        case V1_1::OperationType::CONV_2D:
        case V1_1::OperationType::DEPTHWISE_CONV_2D:
        case V1_1::OperationType::DEPTH_TO_SPACE:
        case V1_1::OperationType::DEQUANTIZE:
        case V1_1::OperationType::EMBEDDING_LOOKUP:
        case V1_1::OperationType::FLOOR:
        case V1_1::OperationType::FULLY_CONNECTED:
        case V1_1::OperationType::HASHTABLE_LOOKUP:
        case V1_1::OperationType::L2_NORMALIZATION:
        case V1_1::OperationType::L2_POOL_2D:
        case V1_1::OperationType::LOCAL_RESPONSE_NORMALIZATION:
        case V1_1::OperationType::LOGISTIC:
        case V1_1::OperationType::LSH_PROJECTION:
        case V1_1::OperationType::LSTM:
        case V1_1::OperationType::MAX_POOL_2D:
        case V1_1::OperationType::MUL:
        case V1_1::OperationType::RELU:
        case V1_1::OperationType::RELU1:
        case V1_1::OperationType::RELU6:
        case V1_1::OperationType::RESHAPE:
        case V1_1::OperationType::RESIZE_BILINEAR:
        case V1_1::OperationType::RNN:
        case V1_1::OperationType::SOFTMAX:
        case V1_1::OperationType::SPACE_TO_DEPTH:
        case V1_1::OperationType::SVDF:
        case V1_1::OperationType::TANH:
        case V1_1::OperationType::BATCH_TO_SPACE_ND:
        case V1_1::OperationType::DIV:
        case V1_1::OperationType::MEAN:
        case V1_1::OperationType::PAD:
        case V1_1::OperationType::SPACE_TO_BATCH_ND:
        case V1_1::OperationType::SQUEEZE:
        case V1_1::OperationType::STRIDED_SLICE:
        case V1_1::OperationType::SUB:
        case V1_1::OperationType::TRANSPOSE:
        case V1_1::OperationType::OEM_OPERATION:
            return true;
        default:
            return false;
    }
}

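// Validates each operation's type and its inputs/outputs against the operand
// list, and enforces the single-writer rule: every TEMPORARY_VARIABLE or
// MODEL_OUTPUT operand must be written by exactly one operation.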
template<typename VersionedOperation>
static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
                               const hidl_vec<Operand>& operands) {
    const size_t operandCount = operands.size();
    // This vector keeps track of whether there's an operation that writes to
    // each operand. It is used to validate that temporary variables and
    // model outputs will be written to.
    std::vector<bool> writtenTo(operandCount, false);
    for (auto& op : operations) {
        if (!validOperationType(op.type)) {
            LOG(ERROR) << "Invalid operation type " << toString(op.type);
            return false;
        }
        // TODO Validate the shapes and any known values. This is currently
        // done in CpuExecutor but should be done here for all drivers.
        int error =
                validateOperation(static_cast<int32_t>(op.type), op.inputs.size(),
                                  op.inputs.size() > 0 ? op.inputs.data() : nullptr,
                                  op.outputs.size(),
                                  op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands);
        if (error != ANEURALNETWORKS_NO_ERROR) {
            return false;
        }

        for (uint32_t i : op.outputs) {
            const Operand& operand = operands[i];
            if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE &&
                operand.lifetime != OperandLifeTime::MODEL_OUTPUT) {
                LOG(ERROR) << "Writing to an operand with incompatible lifetime "
                           << toString(operand.lifetime);
                return false;
            }

            // Check that we only write once to an operand.
            if (writtenTo[i]) {
                LOG(ERROR) << "Operand " << i << " written a second time";
                return false;
            }
            writtenTo[i] = true;
        }
    }
    for (size_t i = 0; i < operandCount; i++) {
        if (!writtenTo[i]) {
            const Operand& operand = operands[i];
            if (operand.lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
                operand.lifetime == OperandLifeTime::MODEL_OUTPUT) {
                LOG(ERROR) << "Operand " << i << " with lifetime " << toString(operand.lifetime)
                           << " is not being written to.";
                return false;
            }
        }
    }
    // TODO More whole-graph verifications are possible, for example that an
    // operand is not used as both input and output of the same operation,
    // and more generally that the graph is acyclic.
    return true;
}

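// Validates that each memory pool is a non-null handle of a supported type.
// Only "ashmem" and "mmap_fd" memory is accepted by this HAL version.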
static bool validatePools(const hidl_vec<hidl_memory>& pools) {
    for (const hidl_memory& memory : pools) {
        const auto name = memory.name();
        if (name != "ashmem" && name != "mmap_fd") {
            LOG(ERROR) << "Unsupported memory type " << name;
            return false;
        }
        if (memory.handle() == nullptr) {
            LOG(ERROR) << "Memory of type " << name << " is null";
            return false;
        }
    }
    return true;
}

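// Validates that the model's input (or output) indexes are in range, refer to
// operands with the matching lifetime, and contain no duplicates.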
static bool validateModelInputOutputs(const hidl_vec<uint32_t>& indexes,
                                      const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
    const size_t operandCount = operands.size();
    for (uint32_t i : indexes) {
        if (i >= operandCount) {
            LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
            return false;
        }
        const Operand& operand = operands[i];
        if (operand.lifetime != lifetime) {
            LOG(ERROR) << "Model input or output has lifetime of " << toString(operand.lifetime)
                       << " instead of the expected " << toString(lifetime);
            return false;
        }
    }

    std::vector<uint32_t> sortedIndexes = indexes;
    std::sort(sortedIndexes.begin(), sortedIndexes.end());
    auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
    if (adjacentI != sortedIndexes.end()) {
        LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
        return false;
    }
    return true;
}

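// Shared implementation for the V1_0 and V1_1 entry points below. The
// per-version operation-type check is resolved through the validOperationType
// overloads above.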
template<typename VersionedModel>
static bool validateModelVersioned(const VersionedModel& model) {
    return (validateOperands(model.operands, model.operandValues, model.pools) &&
            validateOperations(model.operations, model.operands) &&
            validateModelInputOutputs(model.inputIndexes, model.operands,
                                      OperandLifeTime::MODEL_INPUT) &&
            validateModelInputOutputs(model.outputIndexes, model.operands,
                                      OperandLifeTime::MODEL_OUTPUT) &&
            validatePools(model.pools));
}

bool validateModel(const V1_0::Model& model) {
    return validateModelVersioned(model);
}

bool validateModel(const V1_1::Model& model) {
    return validateModelVersioned(model);
}

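// Illustrative driver-side usage (a sketch only; "MyDriver" and the error
// handling are hypothetical, while the prepareModel_1_1 signature is the one
// defined by the V1_1 IDevice HAL):
//
//   Return<ErrorStatus> MyDriver::prepareModel_1_1(
//           const V1_1::Model& model, ExecutionPreference preference,
//           const sp<IPreparedModelCallback>& callback) {
//       if (!validateModel(model) || !validateExecutionPreference(preference)) {
//           callback->notify(ErrorStatus::INVALID_ARGUMENT, nullptr);
//           return ErrorStatus::INVALID_ARGUMENT;
//       }
//       // ... compile the model and notify the callback ...
//   }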
// Validates the arguments of a request. type is either "input" or "output" and is
// used for printing error messages. operandIndexes is the array of input or output
// operand indexes that was passed to ANeuralNetworksModel_identifyInputsAndOutputs.
static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
                                     const hidl_vec<uint32_t>& operandIndexes,
                                     const hidl_vec<Operand>& operands,
                                     const hidl_vec<hidl_memory>& pools, const char* type) {
    MemoryAccessVerifier poolVerifier(pools);
    // The request should specify as many arguments as were described in the model.
    const size_t requestArgumentCount = requestArguments.size();
    if (requestArgumentCount != operandIndexes.size()) {
        LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
                   << "s but the model has " << operandIndexes.size();
        return false;
    }
    for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
         requestArgumentIndex++) {
        const RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
        const DataLocation& location = requestArgument.location;
        // Get the operand index for this argument. We extract it from the list
        // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
        // We assume in this function that the model has been validated already.
        const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
        const Operand& operand = operands[operandIndex];
        if (requestArgument.hasNoValue) {
            if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
                requestArgument.dimensions.size() != 0) {
                LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                           << " has no value yet has details.";
                return false;
            }
        } else {
            // Validate the location.
            if (!poolVerifier.validate(location)) {
                return false;
            }
            // If the argument specifies dimensions, validate them.
            uint32_t rank = requestArgument.dimensions.size();
            if (rank == 0) {
                // Validate that all the dimensions are specified in the model.
                for (size_t i = 0; i < operand.dimensions.size(); i++) {
                    if (operand.dimensions[i] == 0) {
                        LOG(ERROR) << "Model has dimension " << i
                                   << " set to 0 but the request does not specify the dimension.";
                        return false;
                    }
                }
            } else {
                if (rank != operand.dimensions.size()) {
                    LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                               << " has number of dimensions (" << rank
                               << ") different than the model's (" << operand.dimensions.size()
                               << ")";
                    return false;
                }
                for (size_t i = 0; i < rank; i++) {
                    if (requestArgument.dimensions[i] != operand.dimensions[i] &&
                        operand.dimensions[i] != 0) {
                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                                   << " has dimension " << i << " of "
                                   << requestArgument.dimensions[i]
                                   << " different than the model's " << operand.dimensions[i];
                        return false;
                    }
                    if (requestArgument.dimensions[i] == 0) {
                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                                   << " has dimension " << i << " of zero";
                        return false;
                    }
                }
            }
        }
    }
    return true;
}

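// Shared implementation for the V1_0 and V1_1 validateRequest entry points:
// a request is valid if its inputs and outputs match the model's declared
// inputs and outputs and all of its memory pools are usable.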
template<typename VersionedModel>
static bool validateRequestVersioned(const Request& request, const VersionedModel& model) {
    return (validateRequestArguments(request.inputs, model.inputIndexes, model.operands,
                                     request.pools, "input") &&
            validateRequestArguments(request.outputs, model.outputIndexes, model.operands,
                                     request.pools, "output") &&
            validatePools(request.pools));
}

bool validateRequest(const Request& request, const V1_0::Model& model) {
    return validateRequestVersioned(request, model);
}

bool validateRequest(const Request& request, const V1_1::Model& model) {
    return validateRequestVersioned(request, model);
}

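// Returns true if the preference is one of the values defined by the V1_1 HAL.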
bool validateExecutionPreference(ExecutionPreference preference) {
    return preference == ExecutionPreference::LOW_POWER ||
           preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
           preference == ExecutionPreference::SUSTAINED_SPEED;
}

} // namespace nn
} // namespace android