/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "ValidateHal"

#include "ValidateHal.h"

#include <android-base/logging.h>

#include <algorithm>
#include <set>
#include <utility>
#include <vector>

#include "NeuralNetworks.h"
#include "OperationsUtils.h"
#include "Tracing.h"
#include "Utils.h"

namespace android {
namespace nn {

using namespace hal;

template <class T_Model>
struct ModelToHalVersion;
template <>
struct ModelToHalVersion<V1_0::Model> {
    static constexpr HalVersion version = HalVersion::V1_0;
};
template <>
struct ModelToHalVersion<V1_1::Model> {
    static constexpr HalVersion version = HalVersion::V1_1;
};
template <>
struct ModelToHalVersion<V1_2::Model> {
    static constexpr HalVersion version = HalVersion::V1_2;
};
template <>
struct ModelToHalVersion<V1_3::Model> {
    static constexpr HalVersion version = HalVersion::V1_3;
};

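// Precomputes the sizes of a request's or model's memory pools so that each
// DataLocation can be bounds-checked without re-querying the pools. A minimal
// usage sketch (hypothetical caller; `request` stands for any request with a
// `pools` member):
//
//     MemoryAccessVerifier verifier(request.pools);
//     if (!verifier.validate(someOperand.location)) return false;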
class MemoryAccessVerifier {
   public:
    MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
        : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
        for (size_t i = 0; i < mPoolCount; i++) {
            mPoolSizes[i] = pools[i].size();
        }
    }
    MemoryAccessVerifier(const hidl_vec<V1_3::Request::MemoryPool>& pools)
        : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
        for (size_t i = 0; i < mPoolCount; i++) {
            switch (pools[i].getDiscriminator()) {
                case Request::MemoryPool::hidl_discriminator::hidlMemory:
                    mPoolSizes[i] = pools[i].hidlMemory().size();
                    break;
                case Request::MemoryPool::hidl_discriminator::token:
                    // Set size to 0 to enforce length == 0 && offset == 0.
                    mPoolSizes[i] = 0;
                    break;
            }
        }
    }
    bool validate(const DataLocation& location) const {
        if (location.poolIndex >= mPoolCount) {
            LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
            return false;
        }
        const size_t size = mPoolSizes[location.poolIndex];
        // Do the addition using size_t to avoid potential wrap-around problems.
        if (static_cast<size_t>(location.offset) + location.length > size) {
            LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
                       << location.offset << " and length " << location.length
                       << " exceeds pool size of " << size;
            return false;
        }
        return true;
    }

   private:
    size_t mPoolCount;
    std::vector<size_t> mPoolSizes;
};

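// Checks that operand.extraParams matches operand.type: most types must carry
// no extra parameters, TENSOR_QUANT8_SYMM_PER_CHANNEL must carry well-formed
// channel quantization parameters, and extension types may carry either
// extension data or nothing. OEM types are not validated.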
static bool validateOperandExtraParams(const V1_3::Operand& operand, uint32_t index) {
    switch (operand.type) {
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::SUBGRAPH:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_BOOL8: {
            NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                         OperandExtraParams::hidl_discriminator::none)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type)
                    << " has incorrect extraParams: " << toString(operand.extraParams);
        } break;
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
            NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                         OperandExtraParams::hidl_discriminator::channelQuant)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type)
                    << " without channel quantization params";
            auto& channelQuant = operand.extraParams.channelQuant();

            size_t count = operand.dimensions.size();
            NN_RET_CHECK_LT(channelQuant.channelDim, count)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type)
                    << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
                    << ", must be a valid dimension index in range [0, " << count << ")";
            uint32_t expected = operand.dimensions[channelQuant.channelDim];
            NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type) << " with wrong-sized scales, "
                    << "expected " << expected << " was " << channelQuant.scales.size();
            NN_RET_CHECK_NE(expected, 0)
                    << "Operand " << index << ": Operand of type "
                    << getOperandTypeName(operand.type) << " channel dimension "
                    << channelQuant.channelDim << " is underspecified (can't be 0)";
            for (uint32_t i = 0; i < expected; ++i) {
                NN_RET_CHECK_GT(channelQuant.scales[i], .0f)
                        << "Operand " << index << ": Operand of type "
                        << getOperandTypeName(operand.type)
                        << " with a non-positive value in scales[" << i
                        << "]=" << channelQuant.scales[i];
            }
        } break;
        default: {
            if (isExtensionOperandType(operand.type)) {
                NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
                                     OperandExtraParams::hidl_discriminator::extension ||
                             operand.extraParams.getDiscriminator() ==
                                     OperandExtraParams::hidl_discriminator::none)
                        << "Operand " << index << ": Extension operand of type "
                        << getOperandTypeName(operand.type)
                        << " has incorrect extraParams: " << toString(operand.extraParams);
            }
            // No validation for OEM types.
        } break;
    }
    return true;
}

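// Validates the type, dimensions, scale, zeroPoint, extraParams, lifetime,
// and location of every operand. Each operand is first checked against its
// own HAL version, then upcast to V1_3 for the version-independent checks.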
template <typename VersionedOperand>
static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
                             const hidl_vec<uint8_t>& operandValues,
                             const hidl_vec<hidl_memory>& pools,
                             const hidl_vec<Subgraph>& subgraphs, bool allowUnspecifiedRank) {
    uint32_t index = 0;
    MemoryAccessVerifier poolVerifier(pools);
    for (auto& versionedOperand : operands) {
        if (!validOperandType(versionedOperand.type)) {
            LOG(ERROR) << "Operand is not supported by this version: "
                       << toString(versionedOperand.type);
            return false;
        }
        // Once we are sure the operand is supported by its version, it is safe
        // to convert it to the latest version for the rest of the validations.
        V1_3::Operand operand = convertToV1_3(versionedOperand);
        // Validate type and dimensions.
        switch (operand.type) {
            case OperandType::FLOAT16:
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::BOOL:
            case OperandType::SUBGRAPH:
            case OperandType::OEM: {
                size_t count = operand.dimensions.size();
                if (count != 0) {
                    LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
                               << count;
                    return false;
                }
                break;
            }
            case OperandType::TENSOR_FLOAT16:
            case OperandType::TENSOR_FLOAT32:
            case OperandType::TENSOR_INT32:
            case OperandType::TENSOR_QUANT8_ASYMM:
            case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            case OperandType::TENSOR_QUANT8_SYMM:
            case OperandType::TENSOR_QUANT16_ASYMM:
            case OperandType::TENSOR_QUANT16_SYMM:
            case OperandType::TENSOR_BOOL8:
            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            case OperandType::TENSOR_OEM_BYTE: {
                if ((!allowUnspecifiedRank || operand.lifetime == OperandLifeTime::CONSTANT_COPY ||
                     operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE) &&
                    operand.dimensions.size() == 0) {
                    LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
                    return false;
                }
                break;
            }
            default: {
                if (!isExtensionOperandType(operand.type)) {
                    LOG(ERROR) << "Operand " << index << ": Invalid operand type "
                               << toString(operand.type);
                    return false;
                }
            } break;
        }

        // Validate the scale.
        switch (operand.type) {
            case OperandType::FLOAT16:
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::BOOL:
            case OperandType::SUBGRAPH:
            case OperandType::TENSOR_FLOAT16:
            case OperandType::TENSOR_FLOAT32:
            case OperandType::TENSOR_BOOL8:
            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
                if (operand.scale != 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero scale ("
                               << operand.scale << ")";
                    return false;
                }
                break;
            case OperandType::TENSOR_INT32:
                // TENSOR_INT32 may be used with or without scale, depending on the operation.
                if (operand.scale < 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a negative scale";
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT8_ASYMM:
            case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            case OperandType::TENSOR_QUANT8_SYMM:
            case OperandType::TENSOR_QUANT16_ASYMM:
            case OperandType::TENSOR_QUANT16_SYMM:
                if (operand.scale <= 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-positive scale";
                    return false;
                }
                break;
            default:
                if (isExtensionOperandType(operand.type) && operand.scale != 0.f) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero scale ("
                               << operand.scale << ")";
                    return false;
                }
                // No validation for OEM types.
                // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
                break;
        }

        // Validate the zeroPoint.
        switch (operand.type) {
            case OperandType::FLOAT16:
            case OperandType::FLOAT32:
            case OperandType::INT32:
            case OperandType::UINT32:
            case OperandType::BOOL:
            case OperandType::SUBGRAPH:
            case OperandType::TENSOR_FLOAT16:
            case OperandType::TENSOR_FLOAT32:
            case OperandType::TENSOR_INT32:
            case OperandType::TENSOR_BOOL8:
            case OperandType::TENSOR_QUANT8_SYMM:
            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
                if (operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT8_ASYMM:
                if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [0, 255]";
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
                if (operand.zeroPoint < -128 || operand.zeroPoint > 127) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [-128, 127]";
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT16_ASYMM:
                if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
                               << operand.zeroPoint << ", must be in range [0, 65535]";
                    return false;
                }
                break;
            case OperandType::TENSOR_QUANT16_SYMM:
                if (operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                break;
            default:
                if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) {
                    LOG(ERROR) << "Operand " << index << ": Operand of type "
                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
                               << operand.zeroPoint;
                    return false;
                }
                // No validation for OEM types.
                break;
        }

        NN_RET_CHECK(validateOperandExtraParams(operand, index));

        // Validate the lifetime and the location.
        const DataLocation& location = operand.location;
        switch (operand.lifetime) {
            case OperandLifeTime::CONSTANT_COPY:
                if (location.poolIndex != 0) {
                    LOG(ERROR) << "Operand " << index
                               << ": CONSTANT_COPY with a non-zero poolIndex "
                               << location.poolIndex;
                    return false;
                }
                // Do the addition using size_t to avoid potential wrap-around problems.
                if (static_cast<size_t>(location.offset) + location.length >
                    operandValues.size()) {
                    LOG(ERROR) << "Operand " << index
                               << ": OperandValue location out of range. Starts at "
                               << location.offset << ", length " << location.length << ", max "
                               << operandValues.size();
                    return false;
                }
                break;
            case OperandLifeTime::CONSTANT_REFERENCE:
                if (!poolVerifier.validate(location)) {
                    return false;
                }
                break;
            case OperandLifeTime::TEMPORARY_VARIABLE:
            case OperandLifeTime::SUBGRAPH_INPUT:
            case OperandLifeTime::SUBGRAPH_OUTPUT:
            case OperandLifeTime::NO_VALUE:
                if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
                    LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
                               << location.poolIndex << ", offset " << location.offset
                               << ", or length " << location.length << " for operand of lifetime "
                               << toString(operand.lifetime);
                    return false;
                }
                break;
            case OperandLifeTime::SUBGRAPH: {
                if (location.poolIndex != 0) {
                    LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
                               << location.poolIndex;
                    return false;
                }
                if (location.offset >= subgraphs.size()) {
                    LOG(ERROR) << "Subgraph index out of range: " << location.offset
                               << " >= " << subgraphs.size();
                    return false;
                }
                if (location.length != 0) {
                    LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
                               << location.length;
                    return false;
                }
            } break;
            default:
                LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
                           << toString(operand.lifetime);
                return false;
        }

        // Make sure the SUBGRAPH operand type and lifetime always go together.
        if ((operand.type == OperandType::SUBGRAPH) !=
            (operand.lifetime == OperandLifeTime::SUBGRAPH)) {
            LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
                       << " cannot have lifetime " << toString(operand.lifetime);
            return false;
        }

        // For constants, validate that the length is as expected. The other lifetimes
        // expect the length to be 0. Don't validate for OEM types.
        if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
            operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
            if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM &&
                operand.type != OperandType::TENSOR_OEM_BYTE) {
                uint32_t expectedLength = nonExtensionOperandSizeOfData(operand);
                if (location.length != expectedLength) {
                    LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
                               << " expected a size of " << expectedLength << " but got "
                               << location.length;
                    return false;
                }
            }
        }

        index++;
    }
    return true;
}

static HalVersion getHalVersion(const V1_0::Operation&) {
    return HalVersion::V1_0;
}

static HalVersion getHalVersion(const V1_1::Operation&) {
    return HalVersion::V1_1;
}

static HalVersion getHalVersion(const V1_2::Operation&) {
    return HalVersion::V1_2;
}

static HalVersion getHalVersion(const V1_3::Operation&) {
    return HalVersion::V1_3;
}

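// Validates every operation against the operands it reads and writes. The
// lambdas below let validateOperation() resolve SUBGRAPH operands so that
// control flow operations (such as IF and WHILE) can be checked against the
// signatures of the subgraphs they reference.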
template <typename VersionedOperation>
static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
                               const hidl_vec<Operand>& operands,
                               const hidl_vec<Subgraph>& subgraphs, ValidationMode mode) {
    auto isValidSubgraphReference = [&subgraphs](const Operand& modelOperand) -> bool {
        NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
                << "Unexpected operand type: " << toString(modelOperand.type);
        NN_RET_CHECK_LT(modelOperand.location.offset, subgraphs.size())
                << "Invalid subgraph reference";
        return true;
    };
    auto getSubgraph = [&subgraphs](const Operand& modelOperand) -> const Subgraph* {
        CHECK_LT(modelOperand.location.offset, subgraphs.size());
        return &subgraphs[modelOperand.location.offset];
    };
    auto getInputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
        return getSubgraph(modelOperand)->inputIndexes.size();
    };
    auto getOutputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
        return getSubgraph(modelOperand)->outputIndexes.size();
    };
    auto getInputOperand = [&getSubgraph](const Operand& modelOperand,
                                          uint32_t index) -> const Operand* {
        const Subgraph& subgraph = *getSubgraph(modelOperand);
        CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());
        return &subgraph.operands[subgraph.inputIndexes[index]];
    };
    auto getOutputOperand = [&getSubgraph](const Operand& modelOperand,
                                           uint32_t index) -> const Operand* {
        const Subgraph& subgraph = *getSubgraph(modelOperand);
        CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());
        return &subgraph.operands[subgraph.outputIndexes[index]];
    };
    for (auto& op : operations) {
        // TODO Validate the shapes and any known values. This is currently
        // done in CpuExecutor but should be done here for all drivers.
        int error = validateOperation(
                static_cast<int32_t>(op.type), op.inputs.size(),
                op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
                op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op),
                {.isValidSubgraphReference = isValidSubgraphReference,
                 .getSubgraphInputCount = getInputCount,
                 .getSubgraphOutputCount = getOutputCount,
                 .getSubgraphInputOperand = getInputOperand,
                 .getSubgraphOutputOperand = getOutputOperand,
                 // The 1.3 HAL does not support control flow operations with operands
                 // of unknown size. See http://b/132458982#comment63.
                 .allowControlFlowOperationWithOperandOfUnknownSize =
                         mode == ValidationMode::RUNTIME});
        if (error != ANEURALNETWORKS_NO_ERROR) {
            LOG(ERROR) << "Invalid operation " << toString(op.type);
            return false;
        }

        // This check is redundant because of the checks in validateGraph(),
        // but it is retained here in order to emit more informative
        // error messages.
        for (uint32_t i : op.outputs) {
            const Operand& operand = operands[i];
            if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE &&
                operand.lifetime != OperandLifeTime::SUBGRAPH_OUTPUT) {
                LOG(ERROR) << "Writing to operand " << i << " with incompatible lifetime "
                           << toString(operand.lifetime);
                return false;
            }
        }
    }
    return true;
}

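// Checks that a memory pool is of a kind the given HAL version understands:
// "ashmem" and "mmap_fd" are always accepted, the hardware buffer types only
// from V1_2 onwards, and the pool's handle must be non-null.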
bool validatePool(const hidl_memory& pool, HalVersion ver) {
    const auto& name = pool.name();
    if (name != "ashmem" && name != "mmap_fd" &&
        ((ver < HalVersion::V1_2) ||
         (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
        LOG(ERROR) << "Unsupported memory type " << name;
        return false;
    }
    if (pool.handle() == nullptr) {
        LOG(ERROR) << "Memory of type " << name << " is null";
        return false;
    }
    return true;
}

bool validatePool(const V1_3::Request::MemoryPool& pool, HalVersion ver) {
    switch (pool.getDiscriminator()) {
        case Request::MemoryPool::hidl_discriminator::hidlMemory:
            return validatePool(pool.hidlMemory(), ver);
        case Request::MemoryPool::hidl_discriminator::token:
            return pool.token() > 0;
    }
    LOG(FATAL) << "unknown MemoryPool discriminator";
    return false;
}

template <class T_MemoryPool>
static bool validatePools(const hidl_vec<T_MemoryPool>& pools, HalVersion ver) {
    return std::all_of(pools.begin(), pools.end(),
                       [ver](const auto& pool) { return validatePool(pool, ver); });
}

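// Verifies that the model input (or output) indexes are in range, that each
// referenced operand has the expected lifetime, and that the index list
// matches the set of operands with that lifetime exactly, with no duplicates.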
static bool validateModelInputOutputs(const hidl_vec<uint32_t>& indexes,
                                      const hidl_vec<Operand>& operands,
                                      OperandLifeTime lifetime) {
    const size_t operandCount = operands.size();
    for (uint32_t i : indexes) {
        if (i >= operandCount) {
            LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
            return false;
        }
        const Operand& operand = operands[i];
        if (operand.lifetime != lifetime) {
            LOG(ERROR) << "Model input or output operand " << i << " has lifetime of "
                       << toString(operand.lifetime) << " instead of the expected "
                       << toString(lifetime);
            return false;
        }
    }

    std::vector<uint32_t> sortedIndexes = indexes;
    std::sort(sortedIndexes.begin(), sortedIndexes.end());
    auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
    if (adjacentI != sortedIndexes.end()) {
        LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
        return false;
    }

    for (size_t i = 0; i < operands.size(); ++i) {
        if (operands[i].lifetime == lifetime &&
            !std::binary_search(sortedIndexes.begin(), sortedIndexes.end(), i)) {
            LOG(ERROR) << "Operand " << i << " marked as " << toString(lifetime)
                       << " but is not included in Model input or output indexes";
            return false;
        }
    }

    return true;
}

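// Verifies that the graph is acyclic and topologically sorted: every
// operation input must be a constant, a subgraph input, or the output of an
// earlier operation. Also cross-checks each operand's numberOfConsumers field
// and confirms that every operand is written at some point.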
template <typename VersionedModelOrSubgraph>
static bool validateGraph(const VersionedModelOrSubgraph& model) {
    // Set up the consumer counts.
    std::vector<uint32_t> operandNumberOfConsumers(model.operands.size(), 0);
    // Either the operand has a known value before model execution
    // begins, or we've seen a writer for this operand while
    // walking operands in execution order.
    std::vector<bool> operandValueKnown(model.operands.size(), false);

    // Mark the operands whose values are known up front.
    for (size_t i = 0; i < model.operands.size(); ++i) {
        const auto& operand = model.operands[i];
        const OperandLifeTime lifetime = convertToV1_3(operand.lifetime);
        operandValueKnown[i] = lifetime == OperandLifeTime::SUBGRAPH_INPUT ||
                               lifetime == OperandLifeTime::CONSTANT_COPY ||
                               lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
                               lifetime == OperandLifeTime::NO_VALUE ||
                               lifetime == OperandLifeTime::SUBGRAPH;
    }

    // Validate that operations are sorted into execution order.
    //
    // If there is a cycle in the graph, the operations will not
    // appear to be sorted into execution order: some operation will
    // have an input for which operandValueKnown[] is false.
    for (size_t i = 0; i < model.operations.size(); ++i) {
        const auto& operation = model.operations[i];

        for (size_t j = 0; j < operation.inputs.size(); ++j) {
            uint32_t k = operation.inputs[j];
            if (!operandValueKnown[k]) {
                LOG(ERROR) << "Operation " << i << " input " << j << " (operand " << k
                           << ") is read before it is written";
                return false;
            }
            operandNumberOfConsumers[k]++;
        }

        for (size_t j = 0; j < operation.outputs.size(); ++j) {
            uint32_t k = operation.outputs[j];
            if (operandValueKnown[k]) {
                // Assuming validateOperations() has returned true, we
                // know that this output is TEMPORARY_VARIABLE or
                // SUBGRAPH_OUTPUT, and so the only way
                // operandValueKnown[k] can be true is if we've
                // already seen a writer for this operand.
                LOG(ERROR) << "Operation " << i << " output " << j << " (operand " << k
                           << ") has already been written";
                return false;
            }
            operandValueKnown[k] = true;
        }
    }

    // Validate the number of consumers.
    //
    // TODO Because we have to validate it, there was no point in including it
    // in struct Operand. For the next release, consider removing unless we have
    // an additional process in system space that creates this value. In that
    // case, it would not have to be validated.
    for (size_t i = 0; i < model.operands.size(); ++i) {
        if (model.operands[i].numberOfConsumers != operandNumberOfConsumers[i]) {
            LOG(ERROR) << "Operand " << i << " has incorrect number of consumers "
                       << model.operands[i].numberOfConsumers << ", expected "
                       << operandNumberOfConsumers[i];
            return false;
        }
    }

    // Verify that all operands are written.
    for (size_t i = 0; i < model.operands.size(); ++i) {
        if (!operandValueKnown[i]) {
            LOG(ERROR) << "Operand " << i << " is never written";
            return false;
        }
    }

    return true;
}

// Makes sure the model does not contain subgraph reference cycles.
static bool checkNoReferenceCycles(const V1_3::Model& model, const V1_3::Subgraph& subgraph,
                                   std::set<const V1_3::Subgraph*>* path) {
    auto [_, isNew] = path->insert(&subgraph);
    if (!isNew) {
        LOG(ERROR) << "Model contains a circular subgraph reference";
        return false;
    }
    for (const Operand& operand : subgraph.operands) {
        if (operand.lifetime == OperandLifeTime::SUBGRAPH) {
            uint32_t refSubgraphIndex = operand.location.offset;
            if (!checkNoReferenceCycles(model, model.referenced[refSubgraphIndex], path)) {
                return false;
            }
        }
    }
    path->erase(&subgraph);
    return true;
}

static bool checkNoReferenceCycles(const V1_3::Model& model) {
    std::set<const V1_3::Subgraph*> path;
    return checkNoReferenceCycles(model, model.main, &path);
}

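// Validates a whole model for the HAL version corresponding to T_Model. A
// minimal sketch of a caller (hypothetical; the surrounding service code and
// error handling are assumptions):
//
//     if (!validateModel(halModel, ValidationMode::RUNTIME)) {
//         return V1_3::ErrorStatus::INVALID_ARGUMENT;
//     }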
template <class T_Model>
bool validateModel(const T_Model& model, ValidationMode mode) {
    NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
    HalVersion version = ModelToHalVersion<T_Model>::version;
    if (model.operations.size() == 0 || model.operands.size() == 0) {
        LOG(ERROR) << "Invalid empty model.";
        return false;
    }
    // We only need versioned operands for their validation. For all the other
    // validations we can use operands upcasted to the latest version.
    const hidl_vec<Operand> latestVersionOperands = convertToV1_3(model.operands);
    return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
                             /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
            validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}, mode) &&
            validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
                                      OperandLifeTime::SUBGRAPH_INPUT) &&
            validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
                                      OperandLifeTime::SUBGRAPH_OUTPUT) &&
            validatePools(model.pools, version) && validateGraph(model));
}

template bool validateModel<V1_0::Model>(const V1_0::Model& model, ValidationMode mode);
template bool validateModel<V1_1::Model>(const V1_1::Model& model, ValidationMode mode);
template bool validateModel<V1_2::Model>(const V1_2::Model& model, ValidationMode mode);

template <>
bool validateModel(const V1_3::Model& model, ValidationMode mode) {
    NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
    if (model.main.operations.size() == 0 || model.main.operands.size() == 0) {
        LOG(ERROR) << "Invalid empty model.";
        return false;
    }
    auto validateSubgraph = [&model, mode](const Subgraph& subgraph) -> bool {
        return (validateOperands(subgraph.operands, model.operandValues, model.pools,
                                 model.referenced, /*allowUnspecifiedRank=*/true) &&
                validateOperations(subgraph.operations, subgraph.operands, model.referenced,
                                   mode) &&
                validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
                                          OperandLifeTime::SUBGRAPH_INPUT) &&
                validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
                                          OperandLifeTime::SUBGRAPH_OUTPUT) &&
                validateGraph(subgraph));
    };
    return (validateSubgraph(model.main) &&
            std::all_of(model.referenced.begin(), model.referenced.end(), validateSubgraph) &&
            validatePools(model.pools, HalVersion::V1_3) && checkNoReferenceCycles(model));
}

// Validates the arguments of a request. type is either "input" or "output" and is used
// for printing error messages. operandIndexes is the appropriate array of input
// or output operand indexes that was passed to ANeuralNetworksModel_identifyInputsAndOutputs.
static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
                                     const hidl_vec<uint32_t>& operandIndexes,
                                     const hidl_vec<Operand>& operands,
                                     const MemoryAccessVerifier& poolVerifier,
                                     bool allowUnspecified, const char* type) {
    // The request should specify as many arguments as were described in the model.
    const size_t requestArgumentCount = requestArguments.size();
    if (requestArgumentCount != operandIndexes.size()) {
        LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
                   << "s but the model has " << operandIndexes.size();
        return false;
    }
    for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
         requestArgumentIndex++) {
        const RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
        const DataLocation& location = requestArgument.location;
        // Get the operand index for this argument. We extract it from the list
        // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
        // We assume in this function that the model has been validated already.
        const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
        const Operand& operand = operands[operandIndex];
        if (requestArgument.hasNoValue) {
            if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
                requestArgument.dimensions.size() != 0) {
                LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                           << " has no value yet has details.";
                return false;
            }
        } else {
            // Validate the location.
            if (!poolVerifier.validate(location)) {
                return false;
            }
            // If the argument specified a dimension, validate it.
            uint32_t modelRank = operand.dimensions.size();
            uint32_t requestRank = requestArgument.dimensions.size();
            if (requestRank == 0) {
                if (!allowUnspecified) {
                    // Validate that all the dimensions are specified in the model.
                    for (size_t i = 0; i < modelRank; i++) {
                        if (operand.dimensions[i] == 0) {
                            LOG(ERROR)
                                    << "Model has dimension " << i
                                    << " set to 0 but the request does not specify the dimension.";
                            return false;
                        }
                    }
                }
            } else {
                if (modelRank != 0 && requestRank != modelRank) {
                    LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                               << " has number of dimensions (" << requestRank
                               << ") different than the model's (" << modelRank << ")";
                    return false;
                }
                for (size_t i = 0; i < requestRank; i++) {
                    if (modelRank != 0 && requestArgument.dimensions[i] != operand.dimensions[i] &&
                        operand.dimensions[i] != 0) {
                        LOG(ERROR)
                                << "Request " << type << " " << requestArgumentIndex
                                << " has dimension " << i << " of " << requestArgument.dimensions[i]
                                << " different than the model's " << operand.dimensions[i];
                        return false;
                    }
                    if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
                                   << " has dimension " << i << " of zero";
                        return false;
                    }
                }
            }
        }
    }
    return true;
}

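// Validates a request against the model it will run on: argument counts,
// memory locations, dimensions, and pools. Outputs of unspecified dimensions
// are only permitted from V1_2 onwards, and only when the caller allows them.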
template <class T_Request, class T_Model>
bool validateRequest(const T_Request& request, const T_Model& model, bool allowUnspecifiedOutput) {
    HalVersion version = ModelToHalVersion<T_Model>::version;
    MemoryAccessVerifier poolVerifier(request.pools);
    return (validateRequestArguments(request.inputs, model.inputIndexes,
                                     convertToV1_3(model.operands), poolVerifier,
                                     /*allowUnspecified=*/false, "input") &&
            validateRequestArguments(
                    request.outputs, model.outputIndexes, convertToV1_3(model.operands),
                    poolVerifier,
                    /*allowUnspecified=*/version >= HalVersion::V1_2 && allowUnspecifiedOutput,
                    "output") &&
            validatePools(request.pools, version));
}

template bool validateRequest<V1_0::Request, V1_0::Model>(const V1_0::Request& request,
                                                          const V1_0::Model& model,
                                                          bool allowUnspecifiedOutput);
template bool validateRequest<V1_0::Request, V1_1::Model>(const V1_0::Request& request,
                                                          const V1_1::Model& model,
                                                          bool allowUnspecifiedOutput);
template bool validateRequest<V1_0::Request, V1_2::Model>(const V1_0::Request& request,
                                                          const V1_2::Model& model,
                                                          bool allowUnspecifiedOutput);

template <>
bool validateRequest(const V1_3::Request& request, const V1_3::Model& model,
                     bool allowUnspecifiedOutput) {
    return (validateRequestArguments(request.inputs, model.main.inputIndexes, model.main.operands,
                                     request.pools,
                                     /*allowUnspecified=*/false, "input") &&
            validateRequestArguments(request.outputs, model.main.outputIndexes,
                                     model.main.operands, request.pools, allowUnspecifiedOutput,
                                     "output") &&
            validatePools(request.pools, HalVersion::V1_3));
}

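// Validates a V1_3 buffer descriptor (as passed to IDevice::allocate): every
// role must name a valid prepared model and I/O index with a frequency in
// (0.0, 1.0], no role may repeat, and all referenced operands must agree on
// type, scale, zeroPoint, and extraParams so that their dimensions can be
// combined with desc.dimensions into a single operand description.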
bool validateMemoryDesc(const V1_3::BufferDesc& desc,
                        const hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
                        const hidl_vec<V1_3::BufferRole>& inputRoles,
                        const hidl_vec<V1_3::BufferRole>& outputRoles,
                        std::function<const V1_3::Model*(const sp<V1_3::IPreparedModel>&)> getModel,
                        std::set<PreparedModelRole>* preparedModelRoles,
                        V1_3::Operand* combinedOperand) {
    NN_RET_CHECK(preparedModels.size() != 0);
    NN_RET_CHECK(inputRoles.size() != 0 || outputRoles.size() != 0);

    std::set<PreparedModelRole> roles;
    std::vector<V1_3::Operand> operands;
    operands.reserve(inputRoles.size() + outputRoles.size());
    for (const auto& role : inputRoles) {
        NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_RET_CHECK(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_RET_CHECK(model != nullptr);
        const auto& inputIndexes = model->main.inputIndexes;
        NN_RET_CHECK_LT(role.ioIndex, inputIndexes.size());
        NN_RET_CHECK_GT(role.frequency, 0.0f);
        NN_RET_CHECK_LE(role.frequency, 1.0f);
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
        NN_RET_CHECK(success);
        operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
    }
    for (const auto& role : outputRoles) {
        NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
        const auto& preparedModel = preparedModels[role.modelIndex];
        NN_RET_CHECK(preparedModel != nullptr);
        const auto* model = getModel(preparedModel);
        NN_RET_CHECK(model != nullptr);
        const auto& outputIndexes = model->main.outputIndexes;
        NN_RET_CHECK_LT(role.ioIndex, outputIndexes.size());
        NN_RET_CHECK_GT(role.frequency, 0.0f);
        NN_RET_CHECK_LE(role.frequency, 1.0f);
        const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
        NN_RET_CHECK(success);
        operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
    }

    CHECK(!operands.empty());
    const auto opType = operands[0].type;
    const bool isExtension = isExtensionOperandType(opType);

    std::vector<uint32_t> dimensions = desc.dimensions;
    for (const auto& operand : operands) {
        NN_RET_CHECK(operand.type == operands[0].type)
                << toString(operand.type) << " vs " << toString(operands[0].type);
        NN_RET_CHECK_EQ(operand.scale, operands[0].scale);
        NN_RET_CHECK_EQ(operand.zeroPoint, operands[0].zeroPoint);
        // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand types.
        if (!isExtension) {
            NN_RET_CHECK(operand.extraParams == operands[0].extraParams)
                    << toString(operand.extraParams) << " vs " << toString(operands[0].extraParams);
        }
        const auto combined = combineDimensions(dimensions, operand.dimensions);
        NN_RET_CHECK(combined.has_value());
        dimensions = combined.value();
    }

    // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand types.
    if (!isExtension) {
        NN_RET_CHECK(!nonExtensionOperandTypeIsScalar(static_cast<int>(opType)) ||
                     dimensions.empty())
                << "invalid dimensions with scalar operand type.";
    }

    if (preparedModelRoles != nullptr) {
        *preparedModelRoles = std::move(roles);
    }
    if (combinedOperand != nullptr) {
        *combinedOperand = operands[0];
        combinedOperand->dimensions = dimensions;
    }
    return true;
}

bool validateExecutionPreference(ExecutionPreference preference) {
    return preference == ExecutionPreference::LOW_POWER ||
           preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
           preference == ExecutionPreference::SUSTAINED_SPEED;
}

bool validatePriority(Priority priority) {
    return priority == Priority::LOW || priority == Priority::MEDIUM || priority == Priority::HIGH;
}

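// The validOperandType() overloads enumerate the operand types known to each
// HAL version; from V1_2 onwards an unrecognized value may still be a valid
// extension operand type.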
bool validOperandType(V1_0::OperandType operandType) {
    switch (operandType) {
        case V1_0::OperandType::FLOAT32:
        case V1_0::OperandType::INT32:
        case V1_0::OperandType::UINT32:
        case V1_0::OperandType::TENSOR_FLOAT32:
        case V1_0::OperandType::TENSOR_INT32:
        case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_0::OperandType::OEM:
        case V1_0::OperandType::TENSOR_OEM_BYTE:
            return true;
        default:
            return false;
    }
}

bool validOperandType(V1_2::OperandType operandType) {
    switch (operandType) {
        case V1_2::OperandType::FLOAT16:
        case V1_2::OperandType::FLOAT32:
        case V1_2::OperandType::INT32:
        case V1_2::OperandType::UINT32:
        case V1_2::OperandType::BOOL:
        case V1_2::OperandType::TENSOR_FLOAT16:
        case V1_2::OperandType::TENSOR_FLOAT32:
        case V1_2::OperandType::TENSOR_INT32:
        case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_2::OperandType::TENSOR_QUANT8_SYMM:
        case V1_2::OperandType::TENSOR_QUANT16_ASYMM:
        case V1_2::OperandType::TENSOR_QUANT16_SYMM:
        case V1_2::OperandType::TENSOR_BOOL8:
        case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case V1_2::OperandType::OEM:
        case V1_2::OperandType::TENSOR_OEM_BYTE:
            return true;
        default:
            return isExtensionOperandType(static_cast<V1_3::OperandType>(operandType));
    }
}

bool validOperandType(V1_3::OperandType operandType) {
    switch (operandType) {
        case V1_3::OperandType::FLOAT16:
        case V1_3::OperandType::FLOAT32:
        case V1_3::OperandType::INT32:
        case V1_3::OperandType::UINT32:
        case V1_3::OperandType::BOOL:
        case V1_3::OperandType::TENSOR_FLOAT16:
        case V1_3::OperandType::TENSOR_FLOAT32:
        case V1_3::OperandType::TENSOR_INT32:
        case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
        case V1_3::OperandType::TENSOR_QUANT8_SYMM:
        case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
        case V1_3::OperandType::TENSOR_QUANT16_SYMM:
        case V1_3::OperandType::TENSOR_BOOL8:
        case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
        case V1_3::OperandType::SUBGRAPH:
        case V1_3::OperandType::OEM:
        case V1_3::OperandType::TENSOR_OEM_BYTE:
            return true;
        default:
            return isExtensionOperandType(operandType);
    }
}

}  // namespace nn
}  // namespace android