• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ValidateHal"
18 
19 #include "ValidateHal.h"
20 
21 #include <android-base/logging.h>
22 
23 #include <algorithm>
24 #include <set>
25 #include <utility>
26 #include <vector>
27 
28 #include "NeuralNetworks.h"
29 #include "OperationsUtils.h"
30 #include "Tracing.h"
31 #include "Utils.h"
32 #include "nnapi/TypeUtils.h"
33 
34 namespace android {
35 namespace nn {
36 
37 template <class T_Model>
38 struct ModelToHalVersion;
39 template <>
40 struct ModelToHalVersion<V1_0::Model> {
41     static constexpr HalVersion version = HalVersion::V1_0;
42 };
43 template <>
44 struct ModelToHalVersion<V1_1::Model> {
45     static constexpr HalVersion version = HalVersion::V1_1;
46 };
47 template <>
48 struct ModelToHalVersion<V1_2::Model> {
49     static constexpr HalVersion version = HalVersion::V1_2;
50 };
51 template <>
52 struct ModelToHalVersion<V1_3::Model> {
53     static constexpr HalVersion version = HalVersion::V1_3;
54 };
55 
56 class MemoryAccessVerifier {
57    public:
MemoryAccessVerifier(const hardware::hidl_vec<hardware::hidl_memory> & pools)58     MemoryAccessVerifier(const hardware::hidl_vec<hardware::hidl_memory>& pools)
59         : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
60         for (size_t i = 0; i < mPoolCount; i++) {
61             mPoolSizes[i] = pools[i].size();
62         }
63     }
MemoryAccessVerifier(const hardware::hidl_vec<V1_3::Request::MemoryPool> & pools)64     MemoryAccessVerifier(const hardware::hidl_vec<V1_3::Request::MemoryPool>& pools)
65         : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
66         for (size_t i = 0; i < mPoolCount; i++) {
67             switch (pools[i].getDiscriminator()) {
68                 case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
69                     mPoolSizes[i] = pools[i].hidlMemory().size();
70                     break;
71                 case V1_3::Request::MemoryPool::hidl_discriminator::token:
72                     // Set size to 0 to enforce length == 0 && offset == 0.
73                     mPoolSizes[i] = 0;
74                     break;
75             }
76         }
77     }
validate(const V1_0::DataLocation & location) const78     bool validate(const V1_0::DataLocation& location) const {
79         if (location.poolIndex >= mPoolCount) {
80             LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
81             return false;
82         }
83         const size_t size = mPoolSizes[location.poolIndex];
84         // Do the addition using size_t to avoid potential wrap-around problems.
85         if (static_cast<size_t>(location.offset) + location.length > size) {
86             LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
87                        << location.offset << " and length " << location.length
88                        << " exceeds pool size of " << size;
89             return false;
90         }
91         return true;
92     }
93 
94    private:
95     size_t mPoolCount;
96     std::vector<size_t> mPoolSizes;
97 };
98 
validateOperandExtraParams(const V1_3::Operand & operand,uint32_t index)99 static bool validateOperandExtraParams(const V1_3::Operand& operand, uint32_t index) {
100     switch (operand.type) {
101         case V1_3::OperandType::FLOAT32:
102         case V1_3::OperandType::INT32:
103         case V1_3::OperandType::UINT32:
104         case V1_3::OperandType::BOOL:
105         case V1_3::OperandType::SUBGRAPH:
106         case V1_3::OperandType::TENSOR_FLOAT32:
107         case V1_3::OperandType::TENSOR_FLOAT16:
108         case V1_3::OperandType::TENSOR_INT32:
109         case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
110         case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
111         case V1_3::OperandType::TENSOR_QUANT8_SYMM:
112         case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
113         case V1_3::OperandType::TENSOR_QUANT16_SYMM:
114         case V1_3::OperandType::TENSOR_BOOL8: {
115             NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
116                          V1_2::Operand::ExtraParams::hidl_discriminator::none)
117                     << "Operand " << index << ": Operand of type "
118                     << getOperandTypeName(operand.type)
119                     << " has incorrect extraParams: " << toString(operand.extraParams);
120         } break;
121         case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
122             NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
123                          V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant)
124                     << "Operand " << index << ": Operand of type "
125                     << getOperandTypeName(operand.type) << " without a Channel Quantization params";
126             auto& channelQuant = operand.extraParams.channelQuant();
127 
128             size_t count = operand.dimensions.size();
129             NN_RET_CHECK_LT(channelQuant.channelDim, count)
130                     << "Operand " << index << ": Operand of type "
131                     << getOperandTypeName(operand.type)
132                     << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
133                     << ", must be valid dimension index in range [0, " << count << ")";
134             uint32_t expected = operand.dimensions[channelQuant.channelDim];
135             NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
136                     << "Operand " << index << ": Operand of type "
137                     << getOperandTypeName(operand.type) << " with a wrong-sized scales, "
138                     << "expected " << expected << " was " << channelQuant.scales.size();
139             NN_RET_CHECK_NE(expected, 0u)
140                     << "Operand " << index << ": Operand of type "
141                     << getOperandTypeName(operand.type) << " channel dimension "
142                     << channelQuant.channelDim << " is underspecified (can't be 0)";
143             for (uint32_t i = 0; i < expected; ++i) {
144                 NN_RET_CHECK_GT(channelQuant.scales[i], .0f)
145                         << "Operand " << index << ": Operand of type "
146                         << getOperandTypeName(operand.type) << " with a negative value in scales["
147                         << i << "]=" << channelQuant.scales[i];
148             }
149         } break;
150         default: {
151             if (isExtensionOperandType(operand.type)) {
152                 NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
153                                      V1_2::Operand::ExtraParams::hidl_discriminator::extension ||
154                              operand.extraParams.getDiscriminator() ==
155                                      V1_2::Operand::ExtraParams::hidl_discriminator::none)
156                         << "Operand " << index << ": Extension operand of type "
157                         << getOperandTypeName(operand.type)
158                         << " has incorrect extraParams: " << toString(operand.extraParams);
159             }
160             // No validation for OEM types.
161         } break;
162     }
163     return true;
164 }
165 
166 template <typename VersionedOperand>
validateOperands(const hardware::hidl_vec<VersionedOperand> & operands,const hardware::hidl_vec<uint8_t> & operandValues,const hardware::hidl_vec<hardware::hidl_memory> & pools,const hardware::hidl_vec<V1_3::Subgraph> & subgraphs,bool allowUnspecifiedRank)167 static bool validateOperands(const hardware::hidl_vec<VersionedOperand>& operands,
168                              const hardware::hidl_vec<uint8_t>& operandValues,
169                              const hardware::hidl_vec<hardware::hidl_memory>& pools,
170                              const hardware::hidl_vec<V1_3::Subgraph>& subgraphs,
171                              bool allowUnspecifiedRank) {
172     uint32_t index = 0;
173     MemoryAccessVerifier poolVerifier(pools);
174     for (auto& versionedOperand : operands) {
175         if (!validOperandType(versionedOperand.type)) {
176             LOG(ERROR) << "Operand is not supported by this version: "
177                        << toString(versionedOperand.type);
178             return false;
179         }
180         // Once we are sure the operand is supported by its version, it is safe
181         // to convert it to the latest version for the rest of the validations.
182         V1_3::Operand operand = convertToV1_3(versionedOperand);
183         // Validate type and dimensions.
184         switch (operand.type) {
185             case V1_3::OperandType::FLOAT16:
186             case V1_3::OperandType::FLOAT32:
187             case V1_3::OperandType::INT32:
188             case V1_3::OperandType::UINT32:
189             case V1_3::OperandType::BOOL:
190             case V1_3::OperandType::SUBGRAPH:
191             case V1_3::OperandType::OEM: {
192                 size_t count = operand.dimensions.size();
193                 if (count != 0) {
194                     LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
195                                << count;
196                     return false;
197                 }
198                 break;
199             }
200             case V1_3::OperandType::TENSOR_FLOAT16:
201             case V1_3::OperandType::TENSOR_FLOAT32:
202             case V1_3::OperandType::TENSOR_INT32:
203             case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
204             case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
205             case V1_3::OperandType::TENSOR_QUANT8_SYMM:
206             case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
207             case V1_3::OperandType::TENSOR_QUANT16_SYMM:
208             case V1_3::OperandType::TENSOR_BOOL8:
209             case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
210             case V1_3::OperandType::TENSOR_OEM_BYTE: {
211                 if ((!allowUnspecifiedRank ||
212                      operand.lifetime == V1_3::OperandLifeTime::CONSTANT_COPY ||
213                      operand.lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE) &&
214                     operand.dimensions.size() == 0) {
215                     LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
216                     return false;
217                 }
218                 break;
219             }
220             default: {
221                 if (!isExtensionOperandType(operand.type)) {
222                     LOG(ERROR) << "Operand " << index << ": Invalid operand type "
223                                << toString(operand.type);
224                     return false;
225                 }
226             } break;
227         }
228 
229         // Validate the scale.
230         switch (operand.type) {
231             case V1_3::OperandType::FLOAT16:
232             case V1_3::OperandType::FLOAT32:
233             case V1_3::OperandType::INT32:
234             case V1_3::OperandType::UINT32:
235             case V1_3::OperandType::BOOL:
236             case V1_3::OperandType::SUBGRAPH:
237             case V1_3::OperandType::TENSOR_FLOAT16:
238             case V1_3::OperandType::TENSOR_FLOAT32:
239             case V1_3::OperandType::TENSOR_BOOL8:
240             case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
241                 if (operand.scale != 0.f) {
242                     LOG(ERROR) << "Operand " << index << ": Operand of type "
243                                << getOperandTypeName(operand.type) << " with a non-zero scale ("
244                                << operand.scale << ")";
245                     return false;
246                 }
247                 break;
248             case V1_3::OperandType::TENSOR_INT32:
249                 // TENSOR_INT32 may be used with or without scale, depending on the operation.
250                 if (operand.scale < 0.f) {
251                     LOG(ERROR) << "Operand " << index << ": Operand of type "
252                                << getOperandTypeName(operand.type) << " with a negative scale";
253                     return false;
254                 }
255                 break;
256             case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
257             case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
258             case V1_3::OperandType::TENSOR_QUANT8_SYMM:
259             case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
260             case V1_3::OperandType::TENSOR_QUANT16_SYMM:
261                 if (operand.scale <= 0.f) {
262                     LOG(ERROR) << "Operand " << index << ": Operand of type "
263                                << getOperandTypeName(operand.type) << " with a non-positive scale";
264                     return false;
265                 }
266                 break;
267             default:
268                 if (isExtensionOperandType(operand.type) && operand.scale != 0.f) {
269                     LOG(ERROR) << "Operand " << index << ": Operand of type "
270                                << getOperandTypeName(operand.type) << " with a non-zero scale ("
271                                << operand.scale << ")";
272                     return false;
273                 }
274                 // No validation for OEM types.
275                 // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
276                 break;
277         }
278 
279         // Validate the zeroPoint.
280         switch (operand.type) {
281             case V1_3::OperandType::FLOAT16:
282             case V1_3::OperandType::FLOAT32:
283             case V1_3::OperandType::INT32:
284             case V1_3::OperandType::UINT32:
285             case V1_3::OperandType::BOOL:
286             case V1_3::OperandType::SUBGRAPH:
287             case V1_3::OperandType::TENSOR_FLOAT16:
288             case V1_3::OperandType::TENSOR_FLOAT32:
289             case V1_3::OperandType::TENSOR_INT32:
290             case V1_3::OperandType::TENSOR_BOOL8:
291             case V1_3::OperandType::TENSOR_QUANT8_SYMM:
292             case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
293                 if (operand.zeroPoint != 0) {
294                     LOG(ERROR) << "Operand " << index << ": Operand of type "
295                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
296                                << operand.zeroPoint;
297                     return false;
298                 }
299                 break;
300             case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
301                 if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
302                     LOG(ERROR) << "Operand " << index << ": Operand of type "
303                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
304                                << operand.zeroPoint << ", must be in range [0, 255]";
305                     return false;
306                 }
307                 break;
308             case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
309                 if (operand.zeroPoint < -128 || operand.zeroPoint > 127) {
310                     LOG(ERROR) << "Operand " << index << ": Operand of type "
311                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
312                                << operand.zeroPoint << ", must be in range [-128, 127]";
313                     return false;
314                 }
315                 break;
316             case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
317                 if (operand.zeroPoint < 0 || operand.zeroPoint > 65535) {
318                     LOG(ERROR) << "Operand " << index << ": Operand of type "
319                                << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
320                                << operand.zeroPoint << ", must be in range [0, 65535]";
321                     return false;
322                 }
323                 break;
324             case V1_3::OperandType::TENSOR_QUANT16_SYMM:
325                 if (operand.zeroPoint != 0) {
326                     LOG(ERROR) << "Operand " << index << ": Operand of type "
327                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
328                                << operand.zeroPoint;
329                     return false;
330                 }
331                 break;
332             default:
333                 if (isExtensionOperandType(operand.type) && operand.zeroPoint != 0) {
334                     LOG(ERROR) << "Operand " << index << ": Operand of type "
335                                << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
336                                << operand.zeroPoint;
337                     return false;
338                 }
339                 // No validation for OEM types.
340                 break;
341         }
342 
343         NN_RET_CHECK(validateOperandExtraParams(operand, index));
344 
345         // Validate the lifetime and the location.
346         const V1_0::DataLocation& location = operand.location;
347         switch (operand.lifetime) {
348             case V1_3::OperandLifeTime::CONSTANT_COPY:
349                 if (location.poolIndex != 0) {
350                     LOG(ERROR) << "Operand " << index
351                                << ": CONSTANT_COPY with a non-zero poolIndex "
352                                << location.poolIndex;
353                     return false;
354                 }
355                 // Do the addition using size_t to avoid potential wrap-around problems.
356                 if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) {
357                     LOG(ERROR) << "Operand " << index
358                                << ": OperandValue location out of range.  Starts at "
359                                << location.offset << ", length " << location.length << ", max "
360                                << operandValues.size();
361                     return false;
362                 }
363                 break;
364             case V1_3::OperandLifeTime::CONSTANT_REFERENCE:
365                 if (!poolVerifier.validate(location)) {
366                     return false;
367                 }
368                 break;
369             case V1_3::OperandLifeTime::TEMPORARY_VARIABLE:
370             case V1_3::OperandLifeTime::SUBGRAPH_INPUT:
371             case V1_3::OperandLifeTime::SUBGRAPH_OUTPUT:
372             case V1_3::OperandLifeTime::NO_VALUE:
373                 if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
374                     LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
375                                << location.poolIndex << ", offset " << location.offset
376                                << ", or length " << location.length << " for operand of lifetime "
377                                << toString(operand.lifetime);
378                     return false;
379                 }
380                 break;
381             case V1_3::OperandLifeTime::SUBGRAPH: {
382                 if (location.poolIndex != 0) {
383                     LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
384                                << location.poolIndex;
385                     return false;
386                 }
387                 if (location.offset >= subgraphs.size()) {
388                     LOG(ERROR) << "Model::Subgraph index out of range: " << location.offset
389                                << " >= " << subgraphs.size();
390                     return false;
391                 }
392                 if (location.length != 0) {
393                     LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
394                                << location.length;
395                     return false;
396                 }
397             } break;
398             default:
399                 LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
400                            << toString(operand.lifetime);
401                 return false;
402         }
403 
404         // Make sure SUBGRAPH operand type and lifetime always go together.
405         if ((operand.type == V1_3::OperandType::SUBGRAPH) !=
406             (operand.lifetime == V1_3::OperandLifeTime::SUBGRAPH)) {
407             LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
408                        << " cannot have lifetime " << toString(operand.lifetime);
409             return false;
410         }
411 
412         // For constants, validate that the length is as expected. The other lifetimes
413         // expect the length to be 0. Don't validate for OEM types.
414         if (operand.lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE ||
415             operand.lifetime == V1_3::OperandLifeTime::CONSTANT_COPY) {
416             if (!isExtensionOperandType(operand.type) && operand.type != V1_3::OperandType::OEM &&
417                 operand.type != V1_3::OperandType::TENSOR_OEM_BYTE) {
418                 uint32_t expectedLength = nonExtensionOperandSizeOfData(operand);
419                 if (location.length != expectedLength) {
420                     LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
421                                << " expected a size of " << expectedLength << " but got "
422                                << location.length;
423                     return false;
424                 }
425             }
426         }
427 
428         index++;
429     }
430     return true;
431 }
432 
getHalVersion(const V1_0::Operation &)433 static HalVersion getHalVersion(const V1_0::Operation&) {
434     return HalVersion::V1_0;
435 }
436 
getHalVersion(const V1_1::Operation &)437 static HalVersion getHalVersion(const V1_1::Operation&) {
438     return HalVersion::V1_1;
439 }
440 
getHalVersion(const V1_2::Operation &)441 static HalVersion getHalVersion(const V1_2::Operation&) {
442     return HalVersion::V1_2;
443 }
444 
getHalVersion(const V1_3::Operation &)445 static HalVersion getHalVersion(const V1_3::Operation&) {
446     return HalVersion::V1_3;
447 }
448 
449 template <typename VersionedOperation>
validateOperations(const hardware::hidl_vec<VersionedOperation> & operations,const hardware::hidl_vec<V1_3::Operand> & operands,const hardware::hidl_vec<V1_3::Subgraph> & subgraphs,ValidationMode mode)450 static bool validateOperations(const hardware::hidl_vec<VersionedOperation>& operations,
451                                const hardware::hidl_vec<V1_3::Operand>& operands,
452                                const hardware::hidl_vec<V1_3::Subgraph>& subgraphs,
453                                ValidationMode mode) {
454     auto canonicalSubgraphs = uncheckedConvert(subgraphs);
455     auto isValidSubgraphReference = [&canonicalSubgraphs](const Operand& modelOperand) -> bool {
456         NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
457                 << "Unexpected operand type: " << modelOperand.type;
458         NN_RET_CHECK_LT(modelOperand.location.offset, canonicalSubgraphs.size())
459                 << "Invalid subgraph reference";
460         return true;
461     };
462     auto getSubgraph =
463             [&canonicalSubgraphs](const Operand& modelOperand) -> const Model::Subgraph* {
464         CHECK_LT(modelOperand.location.offset, canonicalSubgraphs.size());
465         return &canonicalSubgraphs[modelOperand.location.offset];
466     };
467     auto getInputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
468         return getSubgraph(modelOperand)->inputIndexes.size();
469     };
470     auto getOutputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
471         return getSubgraph(modelOperand)->outputIndexes.size();
472     };
473     auto getInputOperand = [&getSubgraph](const Operand& modelOperand,
474                                           uint32_t index) -> const Operand* {
475         const Model::Subgraph& subgraph = *getSubgraph(modelOperand);
476         CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());
477         return &subgraph.operands[subgraph.inputIndexes[index]];
478     };
479     auto getOutputOperand = [&getSubgraph](const Operand& modelOperand,
480                                            uint32_t index) -> const Operand* {
481         const Model::Subgraph& subgraph = *getSubgraph(modelOperand);
482         CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());
483         return &subgraph.operands[subgraph.outputIndexes[index]];
484     };
485     for (auto& op : operations) {
486         // TODO Validate the shapes and any known values. This is currently
487         // done in CpuExecutor but should be done here for all drivers.
488         int error = validateOperation(static_cast<int32_t>(op.type), op.inputs.size(),
489                                       op.inputs.size() > 0 ? op.inputs.data() : nullptr,
490                                       op.outputs.size(),
491                                       op.outputs.size() > 0 ? op.outputs.data() : nullptr,
492                                       uncheckedConvert(operands), getHalVersion(op),
493                                       {.isValidSubgraphReference = isValidSubgraphReference,
494                                        .getSubgraphInputCount = getInputCount,
495                                        .getSubgraphOutputCount = getOutputCount,
496                                        .getSubgraphInputOperand = getInputOperand,
497                                        .getSubgraphOutputOperand = getOutputOperand,
498                                        // 1.3 HAL does not support CF operations with operands of
499                                        // unknown size. See http://b/132458982#comment63.
500                                        .allowControlFlowOperationWithOperandOfUnknownSize =
501                                                mode == ValidationMode::RUNTIME});
502         if (error != ANEURALNETWORKS_NO_ERROR) {
503             LOG(ERROR) << "Invalid operation " << toString(op.type);
504             return false;
505         }
506 
507         // This is redundant because of the checks in validateGraph(),
508         // but it is retained here in order to emit more informative
509         // error messages.
510         for (uint32_t i : op.outputs) {
511             const V1_3::Operand& operand = operands[i];
512             if (operand.lifetime != V1_3::OperandLifeTime::TEMPORARY_VARIABLE &&
513                 operand.lifetime != V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) {
514                 LOG(ERROR) << "Writing to operand " << i << " with incompatible lifetime "
515                            << toString(operand.lifetime);
516                 return false;
517             }
518         }
519     }
520     return true;
521 }
522 
validatePool(const hardware::hidl_memory & pool,HalVersion ver)523 bool validatePool(const hardware::hidl_memory& pool, HalVersion ver) {
524     const auto& name = pool.name();
525     if (name != "ashmem" && name != "mmap_fd" &&
526         ((ver < HalVersion::V1_2) ||
527          (name != "hardware_buffer_blob" && name != "hardware_buffer"))) {
528         LOG(ERROR) << "Unsupported memory type " << name;
529         return false;
530     }
531     if (pool.handle() == nullptr) {
532         LOG(ERROR) << "Memory of type " << name << " is null";
533         return false;
534     }
535     return true;
536 }
537 
validatePool(const V1_3::Request::MemoryPool & pool,HalVersion ver)538 bool validatePool(const V1_3::Request::MemoryPool& pool, HalVersion ver) {
539     switch (pool.getDiscriminator()) {
540         case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory:
541             return validatePool(pool.hidlMemory(), ver);
542         case V1_3::Request::MemoryPool::hidl_discriminator::token:
543             return pool.token() > 0;
544     }
545     LOG(FATAL) << "unknown MemoryPool discriminator";
546     return false;
547 }
548 
549 template <class T_MemoryPool>
validatePools(const hardware::hidl_vec<T_MemoryPool> & pools,HalVersion ver)550 static bool validatePools(const hardware::hidl_vec<T_MemoryPool>& pools, HalVersion ver) {
551     return std::all_of(pools.begin(), pools.end(),
552                        [ver](const auto& pool) { return validatePool(pool, ver); });
553 }
554 
validateModelInputOutputs(const hardware::hidl_vec<uint32_t> indexes,const hardware::hidl_vec<V1_3::Operand> & operands,V1_3::OperandLifeTime lifetime)555 static bool validateModelInputOutputs(const hardware::hidl_vec<uint32_t> indexes,
556                                       const hardware::hidl_vec<V1_3::Operand>& operands,
557                                       V1_3::OperandLifeTime lifetime) {
558     const size_t operandCount = operands.size();
559     for (uint32_t i : indexes) {
560         if (i >= operandCount) {
561             LOG(ERROR) << "Model input or output index out of range: " << i << "/" << operandCount;
562             return false;
563         }
564         const V1_3::Operand& operand = operands[i];
565         if (operand.lifetime != lifetime) {
566             LOG(ERROR) << "Model input or output operand " << i << " has lifetime of "
567                        << toString(operand.lifetime) << " instead of the expected "
568                        << toString(lifetime);
569             return false;
570         }
571     }
572 
573     std::vector<uint32_t> sortedIndexes = indexes;
574     std::sort(sortedIndexes.begin(), sortedIndexes.end());
575     auto adjacentI = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
576     if (adjacentI != sortedIndexes.end()) {
577         LOG(ERROR) << "Model input or output occurs multiple times: " << *adjacentI;
578         return false;
579     }
580 
581     for (size_t i = 0; i < operands.size(); ++i) {
582         if (operands[i].lifetime == lifetime &&
583             !binary_search(sortedIndexes.begin(), sortedIndexes.end(), i)) {
584             LOG(ERROR) << "Operand " << i << " marked as " << toString(lifetime)
585                        << " but is not included in Model input or output indexes";
586             return false;
587         }
588     }
589 
590     return true;
591 }
592 
593 template <typename VersionedModelOrSubgraph>
validateGraph(const VersionedModelOrSubgraph & model)594 static bool validateGraph(const VersionedModelOrSubgraph& model) {
595     // set up counts
596     std::vector<uint32_t> operandNumberOfConsumers(model.operands.size(), 0);
597     //     Either the operand has a known value before model execution
598     //     begins, or we've seen a writer for this operand while
599     //     walking operands in execution order.
600     std::vector<bool> operandValueKnown(model.operands.size(), false);
601 
602     // mark known operands
603     for (size_t i = 0; i < model.operands.size(); ++i) {
604         const auto& operand = model.operands[i];
605         const V1_3::OperandLifeTime lifetime = convertToV1_3(operand.lifetime);
606         operandValueKnown[i] = lifetime == V1_3::OperandLifeTime::SUBGRAPH_INPUT ||
607                                lifetime == V1_3::OperandLifeTime::CONSTANT_COPY ||
608                                lifetime == V1_3::OperandLifeTime::CONSTANT_REFERENCE ||
609                                lifetime == V1_3::OperandLifeTime::NO_VALUE ||
610                                lifetime == V1_3::OperandLifeTime::SUBGRAPH;
611     }
612 
613     // Validate that operations are sorted into execution order.
614     //
615     // If there is a cycle in the graph, the operations will not
616     // appear to be sorted into execution order: Some operation will
617     // have an input for which operandValueKnown[] is false.
618     for (size_t i = 0; i < model.operations.size(); ++i) {
619         const auto& operation = model.operations[i];
620 
621         for (size_t j = 0; j < operation.inputs.size(); ++j) {
622             uint32_t k = operation.inputs[j];
623             if (!operandValueKnown[k]) {
624                 LOG(ERROR) << "Operation " << i << " input " << j << " (operand " << k
625                            << ") is read before it is written";
626                 return false;
627             }
628             operandNumberOfConsumers[k]++;
629         }
630 
631         for (size_t j = 0; j < operation.outputs.size(); ++j) {
632             uint32_t k = operation.outputs[j];
633             if (operandValueKnown[k]) {
634                 // Assuming validateOperations() has returned true, we
635                 // know that this output is TEMPORARY_VARIABLE or
636                 // MODEL_OUTPUT, and so the only way
637                 // operandValueKnown[k] can be true is if we've
638                 // already seen a writer for this operand.
639                 LOG(ERROR) << "Operation " << i << " output " << j << " (operand " << k
640                            << ") has already been written";
641                 return false;
642             }
643             operandValueKnown[k] = true;
644         }
645     }
646 
647     // validate number of consumers
648     //
649     // TODO Because we have to validate it, there was no point in including it
650     // in struct Operand. For the next release, consider removing unless we have
651     // an additional process in system space that creates this value. In that
652     // case, it would not have to be validated.
653     for (size_t i = 0; i < model.operands.size(); ++i) {
654         if (model.operands[i].numberOfConsumers != operandNumberOfConsumers[i]) {
655             LOG(ERROR) << "Operand " << i << " has incorrect number of consumers "
656                        << model.operands[i].numberOfConsumers << ", expected "
657                        << operandNumberOfConsumers[i];
658             return false;
659         }
660     }
661 
662     // verify all operands are written
663     for (size_t i = 0; i < model.operands.size(); ++i) {
664         if (!operandValueKnown[i]) {
665             LOG(ERROR) << "Operand " << i << " is never written";
666             return false;
667         }
668     }
669 
670     return true;
671 }
672 
673 // Makes sure the model does not contain subgraph reference cycles.
checkNoReferenceCycles(const V1_3::Model & model,const V1_3::Subgraph & subgraph,std::set<const V1_3::Subgraph * > * path)674 static bool checkNoReferenceCycles(const V1_3::Model& model, const V1_3::Subgraph& subgraph,
675                                    std::set<const V1_3::Subgraph*>* path) {
676     auto [_, isNew] = path->insert(&subgraph);
677     if (!isNew) {
678         LOG(ERROR) << "Model contains a circular subgraph reference";
679         return false;
680     }
681     for (const V1_3::Operand& operand : subgraph.operands) {
682         if (operand.lifetime == V1_3::OperandLifeTime::SUBGRAPH) {
683             uint32_t refSubgraphIndex = operand.location.offset;
684             if (!checkNoReferenceCycles(model, model.referenced[refSubgraphIndex], path)) {
685                 return false;
686             }
687         }
688     }
689     path->erase(&subgraph);
690     return true;
691 }
692 
checkNoReferenceCycles(const V1_3::Model & model)693 static bool checkNoReferenceCycles(const V1_3::Model& model) {
694     std::set<const V1_3::Subgraph*> path;
695     return checkNoReferenceCycles(model, model.main, &path);
696 }
697 
698 template <class T_Model>
validateModel(const T_Model & model,ValidationMode mode)699 bool validateModel(const T_Model& model, ValidationMode mode) {
700     NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
701     HalVersion version = ModelToHalVersion<T_Model>::version;
702     if (model.operations.size() == 0 || model.operands.size() == 0) {
703         LOG(ERROR) << "Invalid empty model.";
704         return false;
705     }
706     // We only need versioned operands for their validation. For all the other
707     // validations we can use operands upcasted to the latest version.
708     const hardware::hidl_vec<V1_3::Operand> latestVersionOperands = convertToV1_3(model.operands);
709     return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
710                              /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
711             validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}, mode) &&
712             validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
713                                       V1_3::OperandLifeTime::SUBGRAPH_INPUT) &&
714             validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
715                                       V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) &&
716             validatePools(model.pools, version) && validateGraph(model));
717 }
718 
719 template bool validateModel<V1_0::Model>(const V1_0::Model& model, ValidationMode mode);
720 template bool validateModel<V1_1::Model>(const V1_1::Model& model, ValidationMode mode);
721 template bool validateModel<V1_2::Model>(const V1_2::Model& model, ValidationMode mode);
722 
723 template <>
validateModel(const V1_3::Model & model,ValidationMode mode)724 bool validateModel(const V1_3::Model& model, ValidationMode mode) {
725     NNTRACE_FULL(NNTRACE_LAYER_UTILITY, NNTRACE_PHASE_UNSPECIFIED, "validateModel");
726     if (model.main.operations.size() == 0 || model.main.operands.size() == 0) {
727         LOG(ERROR) << "Invalid empty model.";
728         return false;
729     }
730     auto validateSubgraph = [&model, mode](const V1_3::Subgraph& subgraph) -> bool {
731         return (validateOperands(subgraph.operands, model.operandValues, model.pools,
732                                  model.referenced, /*allowUnspecifiedRank=*/true) &&
733                 validateOperations(subgraph.operations, subgraph.operands, model.referenced,
734                                    mode) &&
735                 validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
736                                           V1_3::OperandLifeTime::SUBGRAPH_INPUT) &&
737                 validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
738                                           V1_3::OperandLifeTime::SUBGRAPH_OUTPUT) &&
739                 validateGraph(subgraph));
740     };
741     return (validateSubgraph(model.main) &&
742             std::all_of(model.referenced.begin(), model.referenced.end(), validateSubgraph) &&
743             validatePools(model.pools, HalVersion::V1_3) && checkNoReferenceCycles(model));
744 }
745 
746 // Validates the arguments of a request. type is either "input" or "output" and is used
747 // for printing error messages. The operandIndexes is the appropriate array of input
748 // or output operand indexes that was passed to the ANeuralNetworksModel_identifyInputsAndOutputs.
validateRequestArguments(const hardware::hidl_vec<V1_0::RequestArgument> & requestArguments,const hardware::hidl_vec<uint32_t> & operandIndexes,const hardware::hidl_vec<V1_3::Operand> & operands,const MemoryAccessVerifier & poolVerifier,bool allowUnspecified,const char * type)749 static bool validateRequestArguments(
750         const hardware::hidl_vec<V1_0::RequestArgument>& requestArguments,
751         const hardware::hidl_vec<uint32_t>& operandIndexes,
752         const hardware::hidl_vec<V1_3::Operand>& operands, const MemoryAccessVerifier& poolVerifier,
753         bool allowUnspecified, const char* type) {
754     // The request should specify as many arguments as were described in the model.
755     const size_t requestArgumentCount = requestArguments.size();
756     if (requestArgumentCount != operandIndexes.size()) {
757         LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
758                    << "s but the model has " << operandIndexes.size();
759         return false;
760     }
761     for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
762          requestArgumentIndex++) {
763         const V1_0::RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
764         const V1_0::DataLocation& location = requestArgument.location;
765         // Get the operand index for this argument. We extract it from the list
766         // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
767         // We assume in this function that the model has been validated already.
768         const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
769         const V1_3::Operand& operand = operands[operandIndex];
770         if (requestArgument.hasNoValue) {
771             if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
772                 requestArgument.dimensions.size() != 0) {
773                 LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
774                            << " has no value yet has details.";
775                 return false;
776             }
777         } else {
778             // Validate the location.
779             if (!poolVerifier.validate(location)) {
780                 return false;
781             }
782             // If the argument specified a dimension, validate it.
783             uint32_t modelRank = operand.dimensions.size();
784             uint32_t requestRank = requestArgument.dimensions.size();
785             if (requestRank == 0) {
786                 if (!allowUnspecified) {
787                     // NOTE: validateRequestArguments cannot validate unknown tensor rank with
788                     // extension operand type.
789                     if (!isExtensionOperandType(operand.type) &&
790                         !nonExtensionOperandTypeIsScalar(static_cast<int>(operand.type))) {
791                         NN_RET_CHECK_GT(modelRank, 0u)
792                                 << "Model " << type << " " << requestArgumentIndex
793                                 << " has unknown rank but the request does not specify the rank.";
794                     }
795                     // Validate that all the dimensions are specified in the model.
796                     for (size_t i = 0; i < modelRank; i++) {
797                         if (operand.dimensions[i] == 0) {
798                             LOG(ERROR)
799                                     << "Model has dimension " << i
800                                     << " set to 0 but the request does not specify the dimension.";
801                             return false;
802                         }
803                     }
804                 }
805             } else {
806                 if (modelRank != 0 && requestRank != modelRank) {
807                     LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
808                                << " has number of dimensions (" << requestRank
809                                << ") different than the model's (" << modelRank << ")";
810                     return false;
811                 }
812                 for (size_t i = 0; i < requestRank; i++) {
813                     if (modelRank != 0 && requestArgument.dimensions[i] != operand.dimensions[i] &&
814                         operand.dimensions[i] != 0) {
815                         LOG(ERROR)
816                                 << "Request " << type << " " << requestArgumentIndex
817                                 << " has dimension " << i << " of " << requestArgument.dimensions[i]
818                                 << " different than the model's " << operand.dimensions[i];
819                         return false;
820                     }
821                     if (requestArgument.dimensions[i] == 0 && !allowUnspecified) {
822                         LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
823                                    << " has dimension " << i << " of zero";
824                         return false;
825                     }
826                 }
827             }
828         }
829     }
830     return true;
831 }
832 
833 template <class T_Request, class T_Model>
validateRequest(const T_Request & request,const T_Model & model,bool allowUnspecifiedOutput)834 bool validateRequest(const T_Request& request, const T_Model& model, bool allowUnspecifiedOutput) {
835     HalVersion version = ModelToHalVersion<T_Model>::version;
836     MemoryAccessVerifier poolVerifier(request.pools);
837     return (validateRequestArguments(request.inputs, model.inputIndexes,
838                                      convertToV1_3(model.operands), poolVerifier,
839                                      /*allowUnspecified=*/false, "input") &&
840             validateRequestArguments(
841                     request.outputs, model.outputIndexes, convertToV1_3(model.operands),
842                     poolVerifier,
843                     /*allowUnspecified=*/version >= HalVersion::V1_2 && allowUnspecifiedOutput,
844                     "output") &&
845             validatePools(request.pools, version));
846 }
847 
848 template bool validateRequest<V1_0::Request, V1_0::Model>(const V1_0::Request& request,
849                                                           const V1_0::Model& model,
850                                                           bool allowUnspecifiedOutput);
851 template bool validateRequest<V1_0::Request, V1_1::Model>(const V1_0::Request& request,
852                                                           const V1_1::Model& model,
853                                                           bool allowUnspecifiedOutput);
854 template bool validateRequest<V1_0::Request, V1_2::Model>(const V1_0::Request& request,
855                                                           const V1_2::Model& model,
856                                                           bool allowUnspecifiedOutput);
857 
858 template <>
validateRequest(const V1_3::Request & request,const V1_3::Model & model,bool allowUnspecifiedOutput)859 bool validateRequest(const V1_3::Request& request, const V1_3::Model& model,
860                      bool allowUnspecifiedOutput) {
861     return (validateRequestArguments(request.inputs, model.main.inputIndexes, model.main.operands,
862                                      request.pools,
863                                      /*allowUnspecified=*/false, "input") &&
864             validateRequestArguments(request.outputs, model.main.outputIndexes, model.main.operands,
865                                      request.pools, allowUnspecifiedOutput, "output") &&
866             validatePools(request.pools, HalVersion::V1_3));
867 }
868 
validateMemoryDesc(const V1_3::BufferDesc & desc,const hardware::hidl_vec<sp<V1_3::IPreparedModel>> & preparedModels,const hardware::hidl_vec<V1_3::BufferRole> & inputRoles,const hardware::hidl_vec<V1_3::BufferRole> & outputRoles,std::function<const V1_3::Model * (const sp<V1_3::IPreparedModel> &)> getModel,std::set<HalPreparedModelRole> * preparedModelRoles,V1_3::Operand * combinedOperand)869 bool validateMemoryDesc(const V1_3::BufferDesc& desc,
870                         const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
871                         const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
872                         const hardware::hidl_vec<V1_3::BufferRole>& outputRoles,
873                         std::function<const V1_3::Model*(const sp<V1_3::IPreparedModel>&)> getModel,
874                         std::set<HalPreparedModelRole>* preparedModelRoles,
875                         V1_3::Operand* combinedOperand) {
876     NN_RET_CHECK(preparedModels.size() != 0);
877     NN_RET_CHECK(inputRoles.size() != 0 || outputRoles.size() != 0);
878 
879     std::set<HalPreparedModelRole> roles;
880     std::vector<V1_3::Operand> operands;
881     operands.reserve(inputRoles.size() + outputRoles.size());
882     for (const auto& role : inputRoles) {
883         NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
884         const auto& preparedModel = preparedModels[role.modelIndex];
885         NN_RET_CHECK(preparedModel != nullptr);
886         const auto* model = getModel(preparedModel);
887         NN_RET_CHECK(model != nullptr);
888         const auto& inputIndexes = model->main.inputIndexes;
889         NN_RET_CHECK_LT(role.ioIndex, inputIndexes.size());
890         NN_RET_CHECK_GT(role.frequency, 0.0f);
891         NN_RET_CHECK_LE(role.frequency, 1.0f);
892         const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
893         NN_RET_CHECK(success);
894         operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
895     }
896     for (const auto& role : outputRoles) {
897         NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
898         const auto& preparedModel = preparedModels[role.modelIndex];
899         NN_RET_CHECK(preparedModel != nullptr);
900         const auto* model = getModel(preparedModel);
901         NN_RET_CHECK(model != nullptr);
902         const auto& outputIndexes = model->main.outputIndexes;
903         NN_RET_CHECK_LT(role.ioIndex, outputIndexes.size());
904         NN_RET_CHECK_GT(role.frequency, 0.0f);
905         NN_RET_CHECK_LE(role.frequency, 1.0f);
906         const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
907         NN_RET_CHECK(success);
908         operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
909     }
910 
911     CHECK(!operands.empty());
912     const auto opType = operands[0].type;
913     const bool isExtension = isExtensionOperandType(opType);
914 
915     std::vector<uint32_t> dimensions = desc.dimensions;
916     for (const auto& operand : operands) {
917         NN_RET_CHECK(operand.type == operands[0].type)
918                 << toString(operand.type) << " vs " << toString(operands[0].type);
919         NN_RET_CHECK_EQ(operand.scale, operands[0].scale);
920         NN_RET_CHECK_EQ(operand.zeroPoint, operands[0].zeroPoint);
921         // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
922         if (!isExtension) {
923             NN_RET_CHECK(operand.extraParams == operands[0].extraParams)
924                     << toString(operand.extraParams) << " vs " << toString(operands[0].extraParams);
925         }
926         const auto combined = combineDimensions(dimensions, operand.dimensions);
927         NN_RET_CHECK(combined.has_value());
928         dimensions = combined.value();
929     }
930 
931     // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
932     if (!isExtension) {
933         NN_RET_CHECK(!nonExtensionOperandTypeIsScalar(static_cast<int>(opType)) ||
934                      dimensions.empty())
935                 << "invalid dimensions with scalar operand type.";
936     }
937 
938     if (preparedModelRoles != nullptr) {
939         *preparedModelRoles = std::move(roles);
940     }
941     if (combinedOperand != nullptr) {
942         *combinedOperand = operands[0];
943         combinedOperand->dimensions = dimensions;
944     }
945     return true;
946 }
947 
validateExecutionPreference(V1_1::ExecutionPreference preference)948 bool validateExecutionPreference(V1_1::ExecutionPreference preference) {
949     return preference == V1_1::ExecutionPreference::LOW_POWER ||
950            preference == V1_1::ExecutionPreference::FAST_SINGLE_ANSWER ||
951            preference == V1_1::ExecutionPreference::SUSTAINED_SPEED;
952 }
953 
validatePriority(V1_3::Priority priority)954 bool validatePriority(V1_3::Priority priority) {
955     return priority == V1_3::Priority::LOW || priority == V1_3::Priority::MEDIUM ||
956            priority == V1_3::Priority::HIGH;
957 }
958 
validOperandType(V1_0::OperandType operandType)959 bool validOperandType(V1_0::OperandType operandType) {
960     switch (operandType) {
961         case V1_0::OperandType::FLOAT32:
962         case V1_0::OperandType::INT32:
963         case V1_0::OperandType::UINT32:
964         case V1_0::OperandType::TENSOR_FLOAT32:
965         case V1_0::OperandType::TENSOR_INT32:
966         case V1_0::OperandType::TENSOR_QUANT8_ASYMM:
967         case V1_0::OperandType::OEM:
968         case V1_0::OperandType::TENSOR_OEM_BYTE:
969             return true;
970         default:
971             return false;
972     }
973 }
974 
validOperandType(V1_2::OperandType operandType)975 bool validOperandType(V1_2::OperandType operandType) {
976     switch (operandType) {
977         case V1_2::OperandType::FLOAT16:
978         case V1_2::OperandType::FLOAT32:
979         case V1_2::OperandType::INT32:
980         case V1_2::OperandType::UINT32:
981         case V1_2::OperandType::BOOL:
982         case V1_2::OperandType::TENSOR_FLOAT16:
983         case V1_2::OperandType::TENSOR_FLOAT32:
984         case V1_2::OperandType::TENSOR_INT32:
985         case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
986         case V1_2::OperandType::TENSOR_QUANT8_SYMM:
987         case V1_2::OperandType::TENSOR_QUANT16_ASYMM:
988         case V1_2::OperandType::TENSOR_QUANT16_SYMM:
989         case V1_2::OperandType::TENSOR_BOOL8:
990         case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
991         case V1_2::OperandType::OEM:
992         case V1_2::OperandType::TENSOR_OEM_BYTE:
993             return true;
994         default:
995             return isExtensionOperandType(static_cast<V1_3::OperandType>(operandType));
996     }
997 }
998 
validOperandType(V1_3::OperandType operandType)999 bool validOperandType(V1_3::OperandType operandType) {
1000     switch (operandType) {
1001         case V1_3::OperandType::FLOAT16:
1002         case V1_3::OperandType::FLOAT32:
1003         case V1_3::OperandType::INT32:
1004         case V1_3::OperandType::UINT32:
1005         case V1_3::OperandType::BOOL:
1006         case V1_3::OperandType::TENSOR_FLOAT16:
1007         case V1_3::OperandType::TENSOR_FLOAT32:
1008         case V1_3::OperandType::TENSOR_INT32:
1009         case V1_3::OperandType::TENSOR_QUANT8_ASYMM:
1010         case V1_3::OperandType::TENSOR_QUANT8_SYMM:
1011         case V1_3::OperandType::TENSOR_QUANT16_ASYMM:
1012         case V1_3::OperandType::TENSOR_QUANT16_SYMM:
1013         case V1_3::OperandType::TENSOR_BOOL8:
1014         case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
1015         case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
1016         case V1_3::OperandType::SUBGRAPH:
1017         case V1_3::OperandType::OEM:
1018         case V1_3::OperandType::TENSOR_OEM_BYTE:
1019             return true;
1020         default:
1021             return isExtensionOperandType(operandType);
1022     }
1023 }
1024 
1025 }  // namespace nn
1026 }  // namespace android
1027