1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Validation.h"
18 
19 #include <android-base/logging.h>
20 #include <android-base/mapped_file.h>
21 
22 #include <algorithm>
23 #include <cctype>
24 #include <functional>
25 #include <limits>
26 #include <memory>
27 #include <numeric>
28 #include <set>
29 #include <sstream>
30 #include <string>
31 #include <string_view>
32 #include <tuple>
33 #include <utility>
34 #include <variant>
35 #include <vector>
36 
37 #include "ControlFlow.h"
38 #include "OperandTypes.h"
39 #include "OperationResolver.h"
40 #include "OperationTypes.h"
41 #include "Result.h"
42 #include "SharedMemory.h"
43 #include "TypeUtils.h"
44 #include "Types.h"
45 
46 // The NN_VALIDATE family of macros defined below is similar to the CHECK family defined in
47 // system/libbase/include/android-base/logging.h
48 //
49 // The difference is that NN_VALIDATE macros use LOG(ERROR) instead of LOG(FATAL)
50 // and return false instead of aborting.
51 
52 // Logs an error and returns false or Version::INVALID. Append context using << after. For example:
53 //
54 //   NN_VALIDATE_FAIL() << "Something went wrong";
55 //
56 // The containing function must return a bool or Version.
57 #define NN_VALIDATE_FAIL() \
58     return NN_ERROR() << "NN_VALIDATE failed (" << __FILE__ << ":" << __LINE__ << "): "
59 
60 // Logs an error and returns false or Version::INVALID if condition is false. Extra logging can be
61 // appended using << after. For example:
62 //
63 //   NN_VALIDATE(false) << "Something went wrong";
64 //
65 // The containing function must return a bool or Version.
66 #define NN_VALIDATE(condition) \
67     while (UNLIKELY(!(condition))) NN_VALIDATE_FAIL() << #condition << " "
68 
69 // Helper for NN_VALIDATE_xx(x, y) macros.
70 #define NN_VALIDATE_OP(LHS, RHS, OP)                                                        \
71     for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS);                      \
72          UNLIKELY(!(_values.lhs.v OP _values.rhs.v));                                       \
73          /* empty */)                                                                       \
74     NN_VALIDATE_FAIL()                                                                      \
75             << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = "                   \
76             << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
77             << ", " << #RHS << " = "                                                        \
78             << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
79             << ") "
80 
81 // Logs an error and returns false or Version::INVALID if a condition between x and y does not hold.
82 // Extra logging can be appended using << after. For example:
83 //
84 //   NN_VALIDATE_EQ(a, b) << "Something went wrong";
85 //
86 // The values must implement the appropriate comparison operator as well as
87 // `operator<<(std::ostream&, ...)`.
88 // The containing function must return a bool or Version.
89 #define NN_VALIDATE_EQ(x, y) NN_VALIDATE_OP(x, y, ==)
90 #define NN_VALIDATE_NE(x, y) NN_VALIDATE_OP(x, y, !=)
91 #define NN_VALIDATE_LE(x, y) NN_VALIDATE_OP(x, y, <=)
92 #define NN_VALIDATE_LT(x, y) NN_VALIDATE_OP(x, y, <)
93 #define NN_VALIDATE_GE(x, y) NN_VALIDATE_OP(x, y, >=)
94 #define NN_VALIDATE_GT(x, y) NN_VALIDATE_OP(x, y, >)
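
// A minimal usage sketch of these macros, assuming a hypothetical `Foo` type with `initialized`
// and `count` fields; the validation helpers below follow the same shape:
//
//   Result<Version> validateFoo(const Foo& foo) {
//       NN_VALIDATE(foo.initialized) << "Foo is not initialized";
//       NN_VALIDATE_NE(foo.count, 0u) << "Foo must have a non-zero count";
//       return Version::ANDROID_OC_MR1;
//   }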
95 
96 namespace android::nn {
97 namespace {
98 
99 constexpr auto kNullptrVariant = std::variant<const void*, void*>{};
100 constexpr auto kInvalidMemoryDomainToken = Request::MemoryDomainToken{};
101 
102 template <typename Type, typename ValidationFunction>
103 Result<Version> validateVector(const std::vector<Type>& objects,
104                                const ValidationFunction& validationFunction) {
105     auto version = Version::ANDROID_OC_MR1;
106     for (const auto& object : objects) {
107         version = combineVersions(version, NN_TRY(validationFunction(object)));
108     }
109     return version;
110 }
111 
112 bool isValidExtensionName(const std::string& name) {
113     constexpr auto validSymbol = [](char symbol) {
114         return std::islower(symbol) || std::isdigit(symbol) || symbol == '.' || symbol == '_';
115     };
116     const bool hasOnlyValidSymbols = std::all_of(name.begin(), name.end(), validSymbol);
117     const bool hasAtLeastOnePeriod = std::find(name.begin(), name.end(), '.') != name.end();
118     return hasOnlyValidSymbols && hasAtLeastOnePeriod;
119 }
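// For example (illustrative): isValidExtensionName("com.example.test_extension1") is true, while
// isValidExtensionName("ComExample") is false (uppercase letters, no period).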
120 
121 Result<Version> validateDeviceStatus(const DeviceStatus& deviceStatus) {
122     switch (deviceStatus) {
123         case DeviceStatus::AVAILABLE:
124         case DeviceStatus::BUSY:
125         case DeviceStatus::OFFLINE:
126         case DeviceStatus::UNKNOWN:
127             return Version::ANDROID_OC_MR1;
128     }
129     NN_VALIDATE_FAIL() << "Invalid DeviceStatus " << deviceStatus;
130 }
131 
132 Result<Version> validateExecutionPreference(const ExecutionPreference& executionPreference) {
133     switch (executionPreference) {
134         case ExecutionPreference::FAST_SINGLE_ANSWER:
135             // ExecutionPreference::FAST_SINGLE_ANSWER is the default value, so it is implicitly
136             // valid for all versions.
137             return Version::ANDROID_OC_MR1;
138         case ExecutionPreference::LOW_POWER:
139         case ExecutionPreference::SUSTAINED_SPEED:
140             return Version::ANDROID_P;
141     }
142     NN_VALIDATE_FAIL() << "Invalid ExecutionPreference " << executionPreference;
143 }
144 
145 Result<Version> validateDeviceType(const DeviceType& deviceType) {
146     switch (deviceType) {
147         case DeviceType::UNKNOWN:
148             // DeviceType was introduced in the 1.2 NN HAL. DeviceType::UNKNOWN is returned when
149             // querying versions that are prior to the 1.2 NN HAL. DeviceType::UNKNOWN is not a
150             // valid code to return for a driver that implements at least a 1.2 NN HAL. If we need a
151             // range of versions, make ANDROID_Q (NN HAL 1.2) the exclusive upper bound for
152             // DeviceType::UNKNOWN.
153             return Version::ANDROID_OC_MR1;
154         case DeviceType::OTHER:
155         case DeviceType::CPU:
156         case DeviceType::GPU:
157         case DeviceType::ACCELERATOR:
158             return Version::ANDROID_Q;
159     }
160     NN_VALIDATE_FAIL() << "Invalid DeviceType " << deviceType;
161 }
162 
163 Result<Version> validateMeasureTiming(const MeasureTiming& measureTiming) {
164     switch (measureTiming) {
165         case MeasureTiming::NO:
166             // MeasureTiming::NO is the default value, so it is implicitly valid for all versions.
167             return Version::ANDROID_OC_MR1;
168         case MeasureTiming::YES:
169             return Version::ANDROID_Q;
170     }
171     NN_VALIDATE_FAIL() << "Invalid MeasureTiming " << measureTiming;
172 }
173 
174 Result<Version> validateOperandType(const OperandType& operandType) {
175     switch (operandType) {
176         case OperandType::FLOAT32:
177         case OperandType::INT32:
178         case OperandType::UINT32:
179         case OperandType::TENSOR_FLOAT32:
180         case OperandType::TENSOR_INT32:
181         case OperandType::TENSOR_QUANT8_ASYMM:
182         case OperandType::OEM:
183         case OperandType::TENSOR_OEM_BYTE:
184             return Version::ANDROID_OC_MR1;
185         case OperandType::BOOL:
186         case OperandType::TENSOR_QUANT16_SYMM:
187         case OperandType::TENSOR_FLOAT16:
188         case OperandType::TENSOR_BOOL8:
189         case OperandType::FLOAT16:
190         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
191         case OperandType::TENSOR_QUANT16_ASYMM:
192         case OperandType::TENSOR_QUANT8_SYMM:
193             return Version::ANDROID_Q;
194         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
195         case OperandType::SUBGRAPH:
196             return Version::ANDROID_R;
197     }
198     if (isExtension(operandType)) {
199         return Version::ANDROID_Q;
200     }
201     NN_VALIDATE_FAIL() << "Invalid OperandType " << operandType;
202 }
203 
204 Result<Version> validateOperandLifeTime(const Operand& operand) {
205     // Make sure SUBGRAPH operand type and lifetime always go together.
206     NN_VALIDATE_EQ((operand.type == OperandType::SUBGRAPH),
207                    (operand.lifetime == Operand::LifeTime::SUBGRAPH))
208             << "Operand of type " << operand.type << " cannot have lifetime " << operand.lifetime;
209 
210     switch (operand.lifetime) {
211         case Operand::LifeTime::TEMPORARY_VARIABLE:
212         case Operand::LifeTime::SUBGRAPH_INPUT:
213         case Operand::LifeTime::SUBGRAPH_OUTPUT:
214         case Operand::LifeTime::CONSTANT_COPY:
215         case Operand::LifeTime::CONSTANT_REFERENCE:
216         case Operand::LifeTime::NO_VALUE:
217         case Operand::LifeTime::POINTER:
218             return Version::ANDROID_OC_MR1;
219         case Operand::LifeTime::SUBGRAPH:
220             return Version::ANDROID_R;
221     }
222     NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
223 }
224 
225 Result<Version> validatePriority(const Priority& priority) {
226     switch (priority) {
227         case Priority::MEDIUM:
228             // Priority::MEDIUM is the default value, so it is implicitly valid for all versions.
229             return Version::ANDROID_OC_MR1;
230         case Priority::LOW:
231         case Priority::HIGH:
232             return Version::ANDROID_R;
233     }
234     NN_VALIDATE_FAIL() << "Invalid Priority " << priority;
235 }
236 
237 Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) {
238     // Note that MISSED_DEADLINE_*, RESOURCE_EXHAUSTED_*, and DEAD_OBJECT were introduced in
239     // ANDROID_R, but these can be cast to ANDROID_OC_MR1 as GENERAL_FAILURE.
240     switch (errorStatus) {
241         case ErrorStatus::NONE:
242         case ErrorStatus::DEVICE_UNAVAILABLE:
243         case ErrorStatus::GENERAL_FAILURE:
244         case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
245         case ErrorStatus::INVALID_ARGUMENT:
246         case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
247         case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
248         case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
249         case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
250         case ErrorStatus::DEAD_OBJECT:
251             return Version::ANDROID_OC_MR1;
252     }
253     NN_VALIDATE_FAIL() << "Invalid ErrorStatus " << errorStatus;
254 }
255 
256 Result<Version> validateFusedActivationFunc(const FusedActivationFunc& activation) {
257     switch (activation) {
258         case FusedActivationFunc::NONE:
259         case FusedActivationFunc::RELU:
260         case FusedActivationFunc::RELU1:
261         case FusedActivationFunc::RELU6:
262             return Version::ANDROID_OC_MR1;
263     }
264     NN_VALIDATE_FAIL() << "Invalid FusedActivationFunc " << activation;
265 }
266 
267 Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) {
268     return Version::ANDROID_Q;
269 }
270 
271 Result<Version> validateTiming(const Timing& timing) {
272     constexpr auto kNoTiming = Timing{};
273     if (timing == kNoTiming) {
274         // kNoTiming is the default value, so it is implicitly valid for all versions.
275         return Version::ANDROID_OC_MR1;
276     }
277     if (timing.timeInDriver.has_value() && timing.timeOnDevice.has_value()) {
278         // `lazyMessage` is a lazy function to produce the timing validation error message.
279         // Currently, the code is not able to inline the message in NN_VALIDATE due to an
280         // argument-dependent lookup issue with nn::detail::ErrorBuilder interacting with std types
281         // such as std::chrono::duration, so this function uses an indirection through
282         // std::ostringstream.
283         const auto lazyMessage = [&timing]() -> std::string {
284             std::ostringstream oss;
285             oss << "Timing::timeOnDevice (" << timing.timeOnDevice.value()
286                 << ") must not exceed Timing::timeInDriver (" << timing.timeInDriver.value() << ")";
287             return oss.str();
288         };
289         NN_VALIDATE(timing.timeOnDevice.value() <= timing.timeInDriver.value()) << lazyMessage();
290     }
291     return Version::ANDROID_Q;
292 }
293 
294 Result<Version> validateCapabilitiesPerformanceInfo(
295         const Capabilities::PerformanceInfo& performanceInfo) {
296     NN_VALIDATE_GT(performanceInfo.execTime, 0.0f);
297     NN_VALIDATE_GT(performanceInfo.powerUsage, 0.0f);
298     return Version::ANDROID_OC_MR1;
299 }
300 
301 Result<Version> validateCapabilitiesOperandPerformance(
302         const Capabilities::OperandPerformance& operandPerformance) {
303     auto version = NN_TRY(validateOperandType(operandPerformance.type));
304     return combineVersions(version,
305                            NN_TRY(validateCapabilitiesPerformanceInfo(operandPerformance.info)));
306 }
307 
308 Result<Version> validateCapabilitiesOperandPerformanceTable(
309         const Capabilities::OperandPerformanceTable& operandPerformances) {
310     // OperandPerformanceTable's order was validated when it was created, and it is castable to any
311     // version. If an OperandType does not exist in the lower version being converted to, that
312     // OperandPerformance will be dropped.
313     NN_TRY(validateVector(operandPerformances.asVector(), validateCapabilitiesOperandPerformance));
314     return Version::ANDROID_OC_MR1;
315 }
316 
317 Result<Version> validateCapabilities(const Capabilities& capabilities) {
318     auto version =
319             NN_TRY(validateCapabilitiesOperandPerformanceTable(capabilities.operandPerformance));
320 
321     version = combineVersions(version,
322                               NN_TRY(validateCapabilitiesPerformanceInfo(
323                                       capabilities.relaxedFloat32toFloat16PerformanceScalar)));
324     version = combineVersions(version,
325                               NN_TRY(validateCapabilitiesPerformanceInfo(
326                                       capabilities.relaxedFloat32toFloat16PerformanceTensor)));
327     version = combineVersions(
328             version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.ifPerformance)));
329     version = combineVersions(
330             version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.whilePerformance)));
331 
332     return version;
333 }
334 
335 Result<Version> validateExtensionOperandTypeInformation(
336         const Extension::OperandTypeInformation& operandTypeInformation) {
337     NN_VALIDATE_GT(operandTypeInformation.byteSize, 0u);
338     return Version::ANDROID_Q;
339 }
340 
341 Result<Version> validateExtension(const Extension& extension) {
342     NN_VALIDATE(isValidExtensionName(extension.name));
343 
344     // Verify all OperandTypeInformations have unique types.
345     std::vector<uint16_t> types;
346     types.reserve(extension.operandTypes.size());
347     std::transform(extension.operandTypes.begin(), extension.operandTypes.end(),
348                    std::back_inserter(types),
349                    [](const Extension::OperandTypeInformation& operandTypeInformation) {
350                        return operandTypeInformation.type;
351                    });
352     std::sort(types.begin(), types.end());
353     const auto iter = std::adjacent_find(types.begin(), types.end());
354     NN_VALIDATE(iter == types.end()) << "Extension has duplicate type " << *iter;
355 
356     return combineVersions(Version::ANDROID_Q,
357                            NN_TRY(validateVector(extension.operandTypes,
358                                                  validateExtensionOperandTypeInformation)));
359 }
360 
361 Result<Version> validateExtensions(const std::vector<Extension>& extensions) {
362     const auto version = NN_TRY(validateVector(extensions, validateExtension));
363 
364     // Verify all extensions have unique names.
365     std::vector<std::reference_wrapper<const std::string>> names;
366     names.reserve(extensions.size());
367     std::transform(extensions.begin(), extensions.end(), std::back_inserter(names),
368                    [](const Extension& extension) { return std::cref(extension.name); });
369     std::sort(names.begin(), names.end(), std::less<std::string>{});
370     const auto nameIter =
371             std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
372     NN_VALIDATE(nameIter == names.end())
373             << "Two or more extensions have the duplicate name " << nameIter->get();
374 
375     return version;
376 }
377 
378 // Forward declaration of subgraph validation function.
379 Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
380                                       std::optional<size_t> referencedIndex,
381                                       size_t operandValuesSize,
382                                       const std::vector<size_t>& poolSizes,
383                                       const std::vector<Model::Subgraph>& referenced,
384                                       std::vector<std::optional<Version>>* subgraphVersionCache);
385 
386 Result<Version> validateOperandDataLocation(
387         const Operand& operand, size_t operandValuesSize, const std::vector<size_t>& poolSizes,
388         const std::vector<Model::Subgraph>& subgraphs,
389         std::vector<std::optional<Version>>* subgraphVersionCache) {
390     const DataLocation& location = operand.location;
391     NN_VALIDATE_EQ(location.padding, 0u)
392             << "DataLocation with a non-zero padding used in Model: " << location.padding;
393     switch (operand.lifetime) {
394         case Operand::LifeTime::CONSTANT_COPY:
395             NN_VALIDATE(location.pointer == kNullptrVariant)
396                     << "CONSTANT_COPY with a non-null pointer";
397             NN_VALIDATE_EQ(location.poolIndex, 0u)
398                     << "CONSTANT_COPY with a non-zero poolIndex " << location.poolIndex;
399             // Do the addition using uint64_t to avoid potential wrap-around problems.
400             NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length,
401                            operandValuesSize)
402                     << "OperandValue location out of range.  Starts at " << location.offset
403                     << ", length " << location.length << ", max " << operandValuesSize;
404             return Version::ANDROID_OC_MR1;
405         case Operand::LifeTime::CONSTANT_REFERENCE:
406             NN_VALIDATE_LT(location.poolIndex, poolSizes.size());
407             // Do the addition using uint64_t to avoid potential wrap-around problems.
408             NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length,
409                            poolSizes[location.poolIndex])
410                     << "OperandValue location out of range.  Starts at " << location.offset
411                     << ", length " << location.length << ", max " << poolSizes[location.poolIndex];
412             return Version::ANDROID_OC_MR1;
413         case Operand::LifeTime::TEMPORARY_VARIABLE:
414         case Operand::LifeTime::SUBGRAPH_INPUT:
415         case Operand::LifeTime::SUBGRAPH_OUTPUT:
416         case Operand::LifeTime::NO_VALUE:
417             NN_VALIDATE(location.pointer == kNullptrVariant)
418                     << "Unexpected pointer value for operand of lifetime " << operand.lifetime;
419             NN_VALIDATE_EQ(location.poolIndex, 0u)
420                     << "Unexpected poolIndex " << location.poolIndex << " for operand of lifetime "
421                     << operand.lifetime;
422             NN_VALIDATE_EQ(location.offset, 0u) << "Unexpected offset " << location.offset
423                                                 << " for operand of lifetime " << operand.lifetime;
424             NN_VALIDATE_EQ(location.length, 0u) << "Unexpected length " << location.length
425                                                 << " for operand of lifetime " << operand.lifetime;
426             return Version::ANDROID_OC_MR1;
427         case Operand::LifeTime::SUBGRAPH: {
428             NN_VALIDATE(location.pointer == kNullptrVariant) << "SUBGRAPH with a non-null pointer";
429             NN_VALIDATE_EQ(location.poolIndex, 0u)
430                     << "SUBGRAPH with a non-zero poolIndex " << location.poolIndex;
431             NN_VALIDATE_LT(location.offset, subgraphs.size())
432                     << "Subgraph index out of range: " << location.offset
433                     << " >= " << subgraphs.size();
434             NN_VALIDATE_EQ(location.length, 0u)
435                     << "SUBGRAPH with a non-zero length " << location.length;
436             const auto version = NN_TRY(validateModelSubgraph(
437                     subgraphs[location.offset], location.offset, operandValuesSize, poolSizes,
438                     subgraphs, subgraphVersionCache));
439             return combineVersions(version, Version::ANDROID_R);
440         }
441         case Operand::LifeTime::POINTER: {
442             const bool nonNull =
443                     std::visit([](auto* ptr) { return ptr != nullptr; }, location.pointer);
444             NN_VALIDATE(nonNull) << "POINTER with a null pointer";
445             NN_VALIDATE_EQ(location.poolIndex, 0u)
446                     << "POINTER with a non-zero poolIndex " << location.poolIndex;
447             NN_VALIDATE_EQ(location.offset, 0u)
448                     << "POINTER with a non-zero offset " << location.offset;
449             return Version::ANDROID_OC_MR1;
450         }
451     }
452     NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
453 }
454 
455 Result<Version> validateOperandDimensions(const Operand& operand) {
456     switch (operand.type) {
457         case OperandType::FLOAT32:
458         case OperandType::INT32:
459         case OperandType::UINT32:
460         case OperandType::BOOL:
461         case OperandType::FLOAT16:
462         case OperandType::SUBGRAPH:
463         case OperandType::OEM:
464             NN_VALIDATE(operand.dimensions.empty())
465                     << "Scalar data has dimensions of rank " << operand.dimensions.size();
466             return Version::ANDROID_OC_MR1;
467         case OperandType::TENSOR_FLOAT32:
468         case OperandType::TENSOR_INT32:
469         case OperandType::TENSOR_QUANT8_ASYMM:
470         case OperandType::TENSOR_QUANT16_SYMM:
471         case OperandType::TENSOR_FLOAT16:
472         case OperandType::TENSOR_BOOL8:
473         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
474         case OperandType::TENSOR_QUANT16_ASYMM:
475         case OperandType::TENSOR_QUANT8_SYMM:
476         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
477         case OperandType::TENSOR_OEM_BYTE: {
478             if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
479                 operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
480                 operand.lifetime == Operand::LifeTime::POINTER) {
481                 NN_VALIDATE(!operand.dimensions.empty())
482                         << "Tensor has lifetime of " << operand.lifetime
483                         << " but dimensions of rank 0";
484                 const auto size = getNonExtensionSize(operand);
485                 NN_VALIDATE(size.has_value()) << "Tensor dimensions overflow";
486                 NN_VALIDATE_NE(size.value(), 0u) << "Tensor has at least one unknown dimension";
487             }
488             // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before
489             // Android Q?
490             if (operand.dimensions.empty()) {
491                 // Unspecified rank was added in Android Q.
492                 return Version::ANDROID_Q;
493             }
494             return Version::ANDROID_OC_MR1;
495         }
496     }
497     if (isExtension(operand.type)) {
498         // Extension types were added in Android Q.
499         return Version::ANDROID_Q;
500     }
501     NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
502 }
503 
504 Result<Version> validateOperandScale(const Operand& operand) {
505     switch (operand.type) {
506         case OperandType::FLOAT32:
507         case OperandType::INT32:
508         case OperandType::UINT32:
509         case OperandType::TENSOR_FLOAT32:
510         case OperandType::BOOL:
511         case OperandType::TENSOR_FLOAT16:
512         case OperandType::TENSOR_BOOL8:
513         case OperandType::FLOAT16:
514         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
515         case OperandType::SUBGRAPH:
516             NN_VALIDATE_EQ(operand.scale, 0.0f)
517                     << "Operand of type " << operand.type << " with a non-zero scale ("
518                     << operand.scale << ")";
519             return Version::ANDROID_OC_MR1;
520         case OperandType::TENSOR_INT32:
521             // TENSOR_INT32 may be used with or without scale, depending on the operation.
522             // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
523             NN_VALIDATE_GE(operand.scale, 0.0f)
524                     << "Operand of type " << operand.type << " with a negative scale";
525             return Version::ANDROID_OC_MR1;
526         case OperandType::TENSOR_QUANT8_ASYMM:
527         case OperandType::TENSOR_QUANT16_SYMM:
528         case OperandType::TENSOR_QUANT16_ASYMM:
529         case OperandType::TENSOR_QUANT8_SYMM:
530         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
531             NN_VALIDATE_GT(operand.scale, 0.0f)
532                     << "Operand of type " << operand.type << " with a non-positive scale";
533             return Version::ANDROID_OC_MR1;
534         case OperandType::OEM:
535         case OperandType::TENSOR_OEM_BYTE:
536             // No validation for OEM types.
537             return Version::ANDROID_OC_MR1;
538     }
539     if (isExtension(operand.type)) {
540         NN_VALIDATE_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type
541                                             << " with a non-zero scale (" << operand.scale << ")";
542         return Version::ANDROID_Q;
543     }
544     NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
545 }
546 
547 Result<Version> validateOperandZeroPoint(const Operand& operand) {
548     switch (operand.type) {
549         case OperandType::FLOAT32:
550         case OperandType::INT32:
551         case OperandType::UINT32:
552         case OperandType::TENSOR_FLOAT32:
553         case OperandType::TENSOR_INT32:
554         case OperandType::BOOL:
555         case OperandType::TENSOR_FLOAT16:
556         case OperandType::TENSOR_BOOL8:
557         case OperandType::FLOAT16:
558         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
559         case OperandType::TENSOR_QUANT8_SYMM:
560         case OperandType::SUBGRAPH:
561             NN_VALIDATE_EQ(operand.zeroPoint, 0)
562                     << "Operand of type " << operand.type << " with a non-zero zeroPoint "
563                     << operand.zeroPoint;
564             return Version::ANDROID_OC_MR1;
565         case OperandType::TENSOR_QUANT8_ASYMM:
566             NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 255)
567                     << "Operand of type " << operand.type << " with an invalid zeroPoint "
568                     << operand.zeroPoint << ", must be in range [0, 255]";
569             return Version::ANDROID_OC_MR1;
570         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
571             NN_VALIDATE(operand.zeroPoint >= -128 && operand.zeroPoint <= 127)
572                     << "Operand of type " << operand.type << " with an invalid zeroPoint "
573                     << operand.zeroPoint << ", must be in range [-128, 127]";
574             return Version::ANDROID_OC_MR1;
575         case OperandType::TENSOR_QUANT16_ASYMM:
576             NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 65535)
577                     << "Operand of type " << operand.type << " with an invalid zeroPoint "
578                     << operand.zeroPoint << ", must be in range [0, 65535]";
579             return Version::ANDROID_OC_MR1;
580         case OperandType::TENSOR_QUANT16_SYMM:
581             NN_VALIDATE_EQ(operand.zeroPoint, 0)
582                     << "Operand of type " << operand.type << " with a non-zero zeroPoint "
583                     << operand.zeroPoint;
584             return Version::ANDROID_OC_MR1;
585         case OperandType::OEM:
586         case OperandType::TENSOR_OEM_BYTE:
587             // No validation for OEM types.
588             return Version::ANDROID_OC_MR1;
589     }
590     if (isExtension(operand.type)) {
591         NN_VALIDATE_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type
592                                              << " with a non-zero zeroPoint " << operand.zeroPoint;
593         return Version::ANDROID_Q;
594     }
595     NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
596 }
597 
598 Result<Version> validateOperandExtraParams(const Operand& operand) {
599     switch (operand.type) {
600         case OperandType::FLOAT32:
601         case OperandType::INT32:
602         case OperandType::UINT32:
603         case OperandType::TENSOR_FLOAT32:
604         case OperandType::TENSOR_INT32:
605         case OperandType::TENSOR_QUANT8_ASYMM:
606         case OperandType::BOOL:
607         case OperandType::TENSOR_QUANT16_SYMM:
608         case OperandType::TENSOR_FLOAT16:
609         case OperandType::TENSOR_BOOL8:
610         case OperandType::FLOAT16:
611         case OperandType::TENSOR_QUANT16_ASYMM:
612         case OperandType::TENSOR_QUANT8_SYMM:
613         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
614         case OperandType::SUBGRAPH:
615             NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams))
616                     << "Operand of type " << operand.type
617                     << " has extraParams when there must be none";
618             return Version::ANDROID_OC_MR1;
619         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
620             NN_VALIDATE(
621                     std::holds_alternative<Operand::SymmPerChannelQuantParams>(operand.extraParams))
622                     << "Operand of type " << operand.type
623                     << " without a Channel Quantization params";
624             const auto& channelQuant =
625                     std::get<Operand::SymmPerChannelQuantParams>(operand.extraParams);
626 
627             const size_t count = operand.dimensions.size();
628             NN_VALIDATE_LT(channelQuant.channelDim, count)
629                     << "Operand of type " << operand.type
630                     << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
631                     << ", must be valid dimension index in range [0, " << count << ")";
632             const uint32_t expected = operand.dimensions[channelQuant.channelDim];
633             NN_VALIDATE_EQ(channelQuant.scales.size(), expected)
634                     << "Operand of type " << operand.type << " with a wrong-sized scales, expected "
635                     << expected << " was " << channelQuant.scales.size();
636             NN_VALIDATE_NE(expected, 0u)
637                     << "Operand of type " << operand.type << " channel dimension "
638                     << channelQuant.channelDim << " is underspecified (can't be 0)";
639             for (uint32_t i = 0; i < expected; ++i) {
640                 NN_VALIDATE_GT(channelQuant.scales[i], 0.0f)
641                         << "Operand of type " << operand.type
642                         << " with a non-positive value in scales[" << i
643                         << "]=" << channelQuant.scales[i];
644             }
645             return Version::ANDROID_Q;
646         }
647         case OperandType::OEM:
648         case OperandType::TENSOR_OEM_BYTE:
649             // No validation for OEM types.
650             return Version::ANDROID_OC_MR1;
651     }
652     if (isExtension(operand.type)) {
653         NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams) ||
654                     std::holds_alternative<Operand::ExtensionParams>(operand.extraParams))
655                 << "Extension operand of type " << operand.type
656                 << " must not have SymmPerChannelQuant extraParams";
657         return Version::ANDROID_OC_MR1;
658     }
659     NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
660 }
661 
662 Result<Version> validateOperand(const Operand& operand, size_t operandValuesSize,
663                                 const std::vector<size_t>& poolSizes,
664                                 const std::vector<Model::Subgraph>& subgraphs,
665                                 std::vector<std::optional<Version>>* subgraphVersionCache) {
666     auto version = NN_TRY(validateOperandType(operand.type));
667     version = combineVersions(version, NN_TRY(validateOperandLifeTime(operand)));
668     version = combineVersions(version, NN_TRY(validateOperandDimensions(operand)));
669     version = combineVersions(version, NN_TRY(validateOperandScale(operand)));
670     version = combineVersions(version, NN_TRY(validateOperandZeroPoint(operand)));
671     version = combineVersions(version, NN_TRY(validateOperandExtraParams(operand)));
672     version = combineVersions(
673             version, NN_TRY(validateOperandDataLocation(operand, operandValuesSize, poolSizes,
674                                                         subgraphs, subgraphVersionCache)));
675 
676     // For constants, validate that the length is as expected. The other lifetimes
677     // expect the length to be 0. Don't validate for OEM types.
678     if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
679         operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
680         operand.lifetime == Operand::LifeTime::POINTER) {
681         if (!isExtension(operand.type) && operand.type != OperandType::OEM &&
682             operand.type != OperandType::TENSOR_OEM_BYTE) {
683             const auto expectedLength = getNonExtensionSize(operand).value();
684             NN_VALIDATE_EQ(operand.location.length, expectedLength)
685                     << "For operand " << operand.type << " expected a size of " << expectedLength
686                     << " but got " << operand.location.length;
687         }
688     }
689 
690     return version;
691 }
692 
693 Result<std::vector<Version>> validateOperands(
694         const std::vector<Operand>& operands, size_t operandValuesSize,
695         const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs,
696         std::vector<std::optional<Version>>* subgraphVersionCache) {
697     std::vector<Version> versions;
698     versions.reserve(operands.size());
699     for (size_t i = 0; i < operands.size(); ++i) {
700         auto result = validateOperand(operands[i], operandValuesSize, poolSizes, subgraphs,
701                                       subgraphVersionCache);
702         if (!result.has_value()) {
703             return error() << std::move(result).error() << " for operand " << i;
704         }
705         versions.push_back(result.value());
706     }
707     return versions;
708 }
709 
710 // Forward declaration.
711 Result<Version> validateOperationIncludingOperandVersions(
712         const Operation& operation, const std::vector<Operand>& operands,
713         const std::vector<Version>& operandVersions, const std::vector<Model::Subgraph>& subgraphs);
714 
715 Result<Version> validateOperations(const std::vector<Operation>& operations,
716                                    const std::vector<Operand>& operands,
717                                    const std::vector<Version>& operandVersions,
718                                    const std::vector<Model::Subgraph>& subgraphs) {
719     auto version = Version::ANDROID_OC_MR1;
720     for (size_t i = 0; i < operations.size(); ++i) {
721         auto result = validateOperationIncludingOperandVersions(operations[i], operands,
722                                                                 operandVersions, subgraphs);
723         if (!result.has_value()) {
724             return error() << std::move(result).error() << " for operation " << i;
725         }
726         version = combineVersions(version, result.value());
727     }
728     return version;
729 }
730 
731 Result<Version> validateHandle(const Handle& handle) {
732     NN_VALIDATE(std::all_of(handle.fds.begin(), handle.fds.end(),
733                             [](const base::unique_fd& fd) { return fd.ok(); }));
734     return Version::ANDROID_OC_MR1;
735 }
736 
737 Result<Version> validateSharedHandle(const SharedHandle& handle) {
738     NN_VALIDATE(handle != nullptr);
739     return validateHandle(*handle);
740 }
741 
742 Result<Version> validateMemory(const Memory::Ashmem& memory) {
743     NN_VALIDATE(memory.fd.ok());
744     NN_VALIDATE_NE(memory.size, 0u);
745     return Version::ANDROID_OC_MR1;
746 }
747 
748 Result<Version> validateMemory(const Memory::Fd& memory) {
749     NN_VALIDATE(memory.fd.ok());
750     NN_VALIDATE_NE(memory.size, 0u);
751 
752     // `prot` is allowed to be either PROT_NONE (which has a value of 0) or any bitwise combination of
753     // PROT_READ and PROT_WRITE. If any other bits are set, the `prot` field is invalid.
754     constexpr int kAllowedBits = PROT_READ | PROT_WRITE;
755     NN_VALIDATE_EQ(memory.prot & ~kAllowedBits, 0);
756 
757     return Version::ANDROID_OC_MR1;
758 }
759 
760 Result<Version> validateMemory(const Memory::HardwareBuffer& memory) {
761     NN_VALIDATE(memory.handle.get() != nullptr);
762     return Version::ANDROID_Q;
763 }
764 
765 Result<Version> validateMemory(const Memory::Unknown& memory) {
766     NN_TRY(validateHandle(memory.handle));
767     return Version::ANDROID_Q;
768 }
769 
770 Result<Version> validateSharedMemory(const SharedMemory& memory) {
771     NN_VALIDATE(memory != nullptr);
772     return std::visit([](const auto& x) { return validateMemory(x); }, memory->handle);
773 }
774 
775 Result<void> validateModelSubgraphInputOutputs(const std::vector<uint32_t>& indexes,
776                                                const std::vector<Operand>& operands,
777                                                Operand::LifeTime lifetime) {
778     const size_t operandCount = operands.size();
779     for (uint32_t i : indexes) {
780         NN_VALIDATE_LT(i, operandCount)
781                 << "Model " << lifetime << " input or output index out of range: " << i << "/"
782                 << operandCount;
783         const Operand& operand = operands[i];
784         NN_VALIDATE_EQ(operand.lifetime, lifetime)
785                 << "Model " << lifetime << " operand " << i << " has lifetime of "
786                 << operand.lifetime << " instead of the expected " << lifetime;
787     }
788 
789     std::vector<uint32_t> sortedIndexes = indexes;
790     std::sort(sortedIndexes.begin(), sortedIndexes.end());
791     const auto iter = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
792     NN_VALIDATE(iter == sortedIndexes.end())
793             << "Model input or output occurs multiple times: " << *iter;
794 
795     for (size_t i = 0; i < operands.size(); ++i) {
796         if (operands[i].lifetime == lifetime) {
797             const auto containsIndex = [&sortedIndexes](size_t index) {
798                 return binary_search(sortedIndexes.begin(), sortedIndexes.end(), index);
799             };
800             NN_VALIDATE(containsIndex(i))
801                     << "Operand " << i << " marked as " << lifetime
802                     << " but is not included in Model input or output indexes";
803         }
804     }
805 
806     return {};
807 }
808 
809 Result<void> validateExecutionOrder(const Model::Subgraph& subgraph) {
810     // Either the operand has a known value before model execution begins, or we've seen a writer
811     // for this operand while walking the operations in execution order. Initialize to known operands.
812     std::vector<bool> operandValueKnown;
813     operandValueKnown.reserve(subgraph.operands.size());
814     std::transform(subgraph.operands.begin(), subgraph.operands.end(),
815                    std::back_inserter(operandValueKnown), [](const Operand& operand) {
816                        return operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE &&
817                               operand.lifetime != Operand::LifeTime::SUBGRAPH_OUTPUT;
818                    });
819 
820     // Validate that operations are sorted into execution order.
821     //
822     // If there is a cycle in the graph, the operations will not
823     // appear to be sorted into execution order: Some operation will
824     // have an input for which operandValueKnown[] is false.
825     for (size_t i = 0; i < subgraph.operations.size(); ++i) {
826         const auto& operation = subgraph.operations[i];
827 
828         for (size_t j = 0; j < operation.inputs.size(); ++j) {
829             const uint32_t k = operation.inputs[j];
830             NN_VALIDATE(operandValueKnown[k]) << "Operation " << i << " input " << j << " (operand "
831                                               << k << ") is read before it is written";
832         }
833 
834         for (size_t j = 0; j < operation.outputs.size(); ++j) {
835             const uint32_t k = operation.outputs[j];
836             // Assuming validateOperations() has not returned an error, we know that this output is
837             // TEMPORARY_VARIABLE or SUBGRAPH_OUTPUT, and so the only way operandValueKnown[k] can be
838             // true is if we've already seen a writer for this operand.
839             NN_VALIDATE(!operandValueKnown[k]) << "Operation " << i << " output " << j
840                                                << " (operand " << k << ") has already been written";
841             operandValueKnown[k] = true;
842         }
843     }
844 
845     // Verify all operands are written.
846     for (size_t i = 0; i < subgraph.operands.size(); ++i) {
847         NN_VALIDATE(operandValueKnown[i]) << "Operand " << i << " is never written";
848     }
849 
850     // TODO(b/77871786): verify that every operation has at least one output operand that is read?
851 
852     return {};
853 }
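
// For example (illustrative): if operation 1 reads a TEMPORARY_VARIABLE operand that is only
// written by operation 2, the loop above fails at operation 1 because operandValueKnown[] for
// that operand is still false, i.e. the operations are not sorted into execution order.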
854 
855 // Validate a subgraph, ensuring all subgraphs it depends on are also validated.
856 //
857 // `referencedIndex` is empty if the subgraph being validated is the main subgraph, otherwise it is
858 // the index of the referenced subgraph being validated.
859 //
860 // referenced[i] and (*subgraphVersionCache)[i] correspond to the same subgraph, and therefore
861 // `referenced` and `subgraphVersionCache` must have the same length.
862 Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
863                                       std::optional<size_t> referencedIndex,
864                                       size_t operandValuesSize,
865                                       const std::vector<size_t>& poolSizes,
866                                       const std::vector<Model::Subgraph>& referenced,
867                                       std::vector<std::optional<Version>>* subgraphVersionCache) {
868     CHECK(subgraphVersionCache != nullptr);
869     CHECK_EQ(referenced.size(), subgraphVersionCache->size());
870 
871     // Quickly return if the current subgraph has already been checked for its version.
872     if (referencedIndex.has_value()) {
873         if (auto version = subgraphVersionCache->at(*referencedIndex)) {
874             return *version;
875         }
876     }
877 
878     NN_VALIDATE(!subgraph.operands.empty());
879     NN_VALIDATE(!subgraph.operations.empty());
880     // TODO(b/173780642): Clarify whether subgraphs with no inputs or outputs are valid.
881     // NN_VALIDATE(!subgraph.inputIndexes.empty());
882     // NN_VALIDATE(!subgraph.outputIndexes.empty());
883 
884     const auto operandVersions = NN_TRY(validateOperands(
885             subgraph.operands, operandValuesSize, poolSizes, referenced, subgraphVersionCache));
886     const auto operationsVersion = NN_TRY(validateOperations(subgraph.operations, subgraph.operands,
887                                                              operandVersions, referenced));
888 
889     // Accumulate the versions from all operands and operations.
890     const auto version = std::accumulate(operandVersions.begin(), operandVersions.end(),
891                                          operationsVersion, combineVersions);
892 
893     NN_TRY(validateModelSubgraphInputOutputs(subgraph.inputIndexes, subgraph.operands,
894                                              Operand::LifeTime::SUBGRAPH_INPUT));
895     NN_TRY(validateModelSubgraphInputOutputs(subgraph.outputIndexes, subgraph.operands,
896                                              Operand::LifeTime::SUBGRAPH_OUTPUT));
897 
898     NN_TRY(validateExecutionOrder(subgraph));
899 
900     // Mark the current subgraph as having already been validated so the caller can quickly return
901     // if this subgraph is checked again.
902     if (referencedIndex.has_value()) {
903         subgraphVersionCache->at(*referencedIndex) = version;
904     }
905     return version;
906 }
907 
908 Result<Version> validateModelExtensionNamesAndPrefixes(
909         const std::vector<Model::ExtensionNameAndPrefix>& extensionNamesAndPrefixes) {
910     for (const auto& extensionNameAndPrefix : extensionNamesAndPrefixes) {
911         NN_VALIDATE(isValidExtensionName(extensionNameAndPrefix.name));
912     }
913 
914     std::vector<std::reference_wrapper<const std::string>> names;
915     names.reserve(extensionNamesAndPrefixes.size());
916     std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
917                    std::back_inserter(names),
918                    [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
919                        return std::cref(extensionNameAndPrefix.name);
920                    });
921     std::sort(names.begin(), names.end(), std::less<std::string>{});
922     const auto nameIter =
923             std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
924     NN_VALIDATE(nameIter == names.end())
925             << "ExtensionNamesAndPrefixes has duplicate name " << nameIter->get();
926 
927     std::vector<uint16_t> types;
928     types.reserve(extensionNamesAndPrefixes.size());
929     std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
930                    std::back_inserter(types),
931                    [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
932                        return extensionNameAndPrefix.prefix;
933                    });
934     std::sort(types.begin(), types.end());
935     const auto typeIter = std::adjacent_find(types.begin(), types.end());
936     NN_VALIDATE(typeIter == types.end())
937             << "ExtensionNamesAndPrefixes has duplicate type " << *typeIter;
938 
939     const bool hasExtensions = !extensionNamesAndPrefixes.empty();
940     return hasExtensions ? Version::ANDROID_Q : Version::ANDROID_OC_MR1;
941 }
942 
943 // Makes sure the model does not contain subgraph reference cycles.
944 //
945 // This function verifies that referencedSubgraphs[subgraphIndex] and any subgraphs it references do
946 // not contain any reference cycles. `path` is used to keep track of which referenced subgraphs have
947 // already been visited in the current recursive reference path. `verified` is a cache to keep track
948 // of which referenced subgraphs have already been verified not to form reference cycles.
949 //
950 // referencedSubgraphs[i], (*path)[i], and (*verified)[i] all correspond to the same subgraph, and
951 // therefore `referencedSubgraphs`, `path`, and `verified` must all have the same length.
952 Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs,
953                                     uint32_t subgraphIndex, std::vector<bool>* path,
954                                     std::vector<bool>* verified) {
955     CHECK(path != nullptr);
956     CHECK(verified != nullptr);
957     CHECK_EQ(referencedSubgraphs.size(), path->size());
958     CHECK_EQ(referencedSubgraphs.size(), verified->size());
959     const auto& subgraph = referencedSubgraphs.at(subgraphIndex);
960 
961     // Quickly return if the current subgraph has already been verified to have no reference cycles.
962     if ((*verified)[subgraphIndex]) {
963         return {};
964     }
965 
966     // Add the current subgraph to the path (making sure that it is not already part of the path),
967     // and verify that all subgraphs this subgraph references do not contain cycles. The current
968     // subgraph is removed from the path only after all subgraphs this subgraph references have been
969     // checked.
970     NN_VALIDATE((*path)[subgraphIndex] == false) << "Model contains a circular subgraph reference";
971     (*path)[subgraphIndex] = true;
972     for (const Operand& operand : subgraph.operands) {
973         if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
974             const uint32_t refSubgraphIndex = operand.location.offset;
975             NN_TRY(checkNoReferenceCycles(referencedSubgraphs, refSubgraphIndex, path, verified));
976         }
977     }
978     (*path)[subgraphIndex] = false;
979 
980     // Mark the current subgraph as having already been verified so the caller can quickly return if
981     // this subgraph is checked again.
982     (*verified)[subgraphIndex] = true;
983     return {};
984 }
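
// For example (illustrative): if subgraph A holds a SUBGRAPH operand referring to subgraph B and
// B holds one referring back to A, the walk above revisits A while it is still on `path` and
// validation fails with "Model contains a circular subgraph reference".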
985 
986 Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs) {
987     const size_t count = referencedSubgraphs.size();
988     std::vector<bool> path(count);
989     std::vector<bool> verified(count);
990     for (size_t i = 0; i < count; ++i) {
991         NN_TRY(checkNoReferenceCycles(referencedSubgraphs, i, &path, &verified));
992     }
993     return {};
994 }
995 
996 Result<Version> validateModel(const Model& model) {
997     auto version = NN_TRY(validateVector(model.pools, validateSharedMemory));
998     version = combineVersions(
999             version, NN_TRY(validateModelExtensionNamesAndPrefixes(model.extensionNameToPrefix)));
1000 
1001     // Ignore relaxComputationFloat32toFloat16 version because in the worst case it makes the
1002     // execution stricter.
1003 
1004     // Referenced models were introduced in Android R.
1005     const bool hasReferencedModels = !model.referenced.empty();
1006     const auto referenceModelVersion =
1007             hasReferencedModels ? Version::ANDROID_R : Version::ANDROID_OC_MR1;
1008     version = combineVersions(version, referenceModelVersion);
1009 
1010     // Ensure that there are no cycles formed by the subgraphs.
1011     NN_TRY(checkNoReferenceCycles(model.referenced));
1012 
1013     // Get memory sizes.
1014     const auto [operandValuesSize, poolSizes] = getMemorySizes(model);
1015 
1016     // Validate referenced subgraphs.
1017     auto subgraphVersionCache = std::vector<std::optional<Version>>(model.referenced.size());
1018     for (size_t referencedIndex = 0; referencedIndex < model.referenced.size(); ++referencedIndex) {
1019         const auto& subgraph = model.referenced[referencedIndex];
1020         const auto subgraphVersion =
1021                 NN_TRY(validateModelSubgraph(subgraph, referencedIndex, operandValuesSize,
1022                                              poolSizes, model.referenced, &subgraphVersionCache));
1023         version = combineVersions(version, subgraphVersion);
1024     }
1025 
1026     // Validate main subgraph.
1027     const auto subgraphVersion =
1028             NN_TRY(validateModelSubgraph(model.main, std::nullopt, operandValuesSize, poolSizes,
1029                                          model.referenced, &subgraphVersionCache));
1030     version = combineVersions(version, subgraphVersion);
1031 
1032     return version;
1033 }
1034 
1035 Result<Version> validateBufferDesc(const BufferDesc& bufferDesc) {
1036     // An empty BufferDesc is the default value, so it is implicitly valid for all versions.
1037     return bufferDesc.dimensions.empty() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R;
1038 }
1039 
1040 Result<Version> validateBufferRole(const BufferRole& bufferRole) {
1041     NN_VALIDATE_GT(bufferRole.probability, 0.0f);
1042     NN_VALIDATE_LE(bufferRole.probability, 1.0f);
1043     return Version::ANDROID_R;
1044 }
1045 
1046 Result<Version> validateRequestArgument(const Request::Argument& requestArgument,
1047                                         const std::vector<size_t>& memorySizes, bool isOutput) {
1048     const auto lifetime = requestArgument.lifetime;
1049     const auto& location = requestArgument.location;
1050     const auto& dimensions = requestArgument.dimensions;
1051 
1052     switch (lifetime) {
1053         case Request::Argument::LifeTime::POOL: {
1054             NN_VALIDATE(location.pointer == kNullptrVariant);
1055             NN_VALIDATE_LT(location.poolIndex, memorySizes.size());
1056             // Do the addition using uint64_t to avoid potential wrap-around problems.
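            // For example, in 32-bit arithmetic an offset of 0xFFFFFFF0 plus a length of 0x20
            // would wrap around to 0x10 and could incorrectly pass the bounds check below.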
1057             const auto lastPosition =
1058                     static_cast<uint64_t>(location.offset) + location.length + location.padding;
1059             const auto memorySize = memorySizes[location.poolIndex];
1060             NN_VALIDATE_LE(lastPosition, memorySize);
1061             if (memorySize > 0) {
1062                 // Must specify a positive length if the memory pool has a known size.
1063                 NN_VALIDATE_GT(location.length, 0u);
1064             }
1065             return Version::ANDROID_OC_MR1;
1066         }
1067         case Request::Argument::LifeTime::NO_VALUE:
1068             NN_VALIDATE(location.pointer == kNullptrVariant);
1069             NN_VALIDATE_EQ(location.poolIndex, 0u);
1070             NN_VALIDATE_EQ(location.offset, 0u);
1071             NN_VALIDATE_EQ(location.length, 0u);
1072             NN_VALIDATE_EQ(location.padding, 0u);
1073             NN_VALIDATE(dimensions.empty());
1074             return Version::ANDROID_OC_MR1;
1075         case Request::Argument::LifeTime::POINTER: {
1076             const bool isNullptr =
1077                     std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
1078             NN_VALIDATE(!isNullptr);
1079             NN_VALIDATE_EQ(location.poolIndex, 0u);
1080             NN_VALIDATE_EQ(location.offset, 0u);
1081             NN_VALIDATE_NE(location.length, 0u);
1082             if (isOutput) {
1083                 NN_VALIDATE(std::holds_alternative<void*>(location.pointer));
1084             }
1085             return Version::ANDROID_OC_MR1;
1086         }
1087     }
1088     NN_VALIDATE_FAIL() << "Invalid Request::Argument::LifeTime " << lifetime;
1089 }
1090 
1091 Result<Version> validateRequestMemoryPool(const Request::MemoryPool& memoryPool) {
1092     if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
1093         NN_VALIDATE(std::get<Request::MemoryDomainToken>(memoryPool) != kInvalidMemoryDomainToken);
1094         return Version::ANDROID_R;
1095     }
1096     if (std::holds_alternative<SharedBuffer>(memoryPool)) {
1097         NN_VALIDATE(std::get<SharedBuffer>(memoryPool) != nullptr);
1098         return Version::ANDROID_R;
1099     }
1100     return validateSharedMemory(std::get<SharedMemory>(memoryPool));
1101 }
1102 
1103 Result<Version> validateRequest(const Request& request) {
1104     auto version = NN_TRY(validateVector(request.pools, validateRequestMemoryPool));
1105 
1106     // Get memory sizes. For IBuffer or MemoryDomainToken types, set size to 0.
1107     std::vector<size_t> memorySizes;
1108     memorySizes.reserve(request.pools.size());
1109     std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(memorySizes),
1110                    [](const Request::MemoryPool& memoryPool) {
1111                        const auto* memory = std::get_if<SharedMemory>(&memoryPool);
1112                        return memory != nullptr ? getSize(*memory) : 0;
1113                    });
1114 
1115     for (size_t i = 0; i < request.inputs.size(); ++i) {
1116         const auto& input = request.inputs[i];
1117         auto result = validateRequestArgument(input, memorySizes, /*isOutput=*/false);
1118         if (!result.has_value()) {
1119             return error() << std::move(result).error() << " for input RequestArgument " << i;
1120         }
1121         version = combineVersions(version, result.value());
1122     }
1123     for (size_t i = 0; i < request.outputs.size(); ++i) {
1124         const auto& output = request.outputs[i];
1125         auto result = validateRequestArgument(output, memorySizes, /*isOutput=*/true);
1126         if (!result.has_value()) {
1127             return error() << std::move(result).error() << " for output RequestArgument " << i;
1128         }
1129         version = combineVersions(version, result.value());
1130     }
1131 
1132     return version;
1133 }
1134 
1135 Result<Version> validateOptionalTimePoint(const OptionalTimePoint& optionalTimePoint) {
1136     if (optionalTimePoint.has_value()) {
1137         NN_VALIDATE_GE(optionalTimePoint->time_since_epoch().count(), 0);
1138     }
1139     // An omitted time point is the default value, so it is implicitly valid for all versions.
1140     return !optionalTimePoint.has_value() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R;
1141 }
1142 
1143 Result<Version> validateOptionalTimeoutDuration(const OptionalDuration& optionalTimeoutDuration) {
1144     if (optionalTimeoutDuration.has_value()) {
1145         NN_VALIDATE_GE(optionalTimeoutDuration->count(), 0);
1146     }
1147     // An omitted duration is the default value, so it is implicitly valid for all versions.
1148     return !optionalTimeoutDuration.has_value() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R;
1149 }
1150 
1151 Result<Version> validateCacheToken(const CacheToken& cacheToken) {
1152     // A CacheToken of 0 is the default value, so it is implicitly valid for all versions.
1153     constexpr auto kDefaultCacheToken = CacheToken{};
1154     return cacheToken == kDefaultCacheToken ? Version::ANDROID_OC_MR1 : Version::ANDROID_Q;
1155 }
1156 
1157 Result<Version> validateSyncFence(const SyncFence& syncFence) {
1158     // The absence of a sync fence is implicitly valid for all versions.
1159     if (!syncFence.hasFd()) {
1160         return Version::ANDROID_OC_MR1;
1161     }
1162     NN_VALIDATE_GE(syncFence.getFd(), 0);
1163     return Version::ANDROID_R;
1164 }
1165 
1166 Result<Version> validateRequestArgumentsForModel(
1167         const std::vector<Request::Argument>& requestArguments,
1168         const std::vector<uint32_t>& operandIndexes, const std::vector<Operand>& operands,
1169         bool isOutput, bool allowUnspecifiedOutput) {
1170     auto version = Version::ANDROID_OC_MR1;
1171     // The request should specify as many arguments as were described in the model.
1172     const std::string_view type = isOutput ? "output" : "input";
1173     const size_t requestArgumentCount = requestArguments.size();
1174     NN_VALIDATE_EQ(requestArgumentCount, operandIndexes.size())
1175             << "Request specifies " << requestArgumentCount << " " << type << "s but the model has "
1176             << operandIndexes.size();
1177     for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
1178          requestArgumentIndex++) {
1179         const Request::Argument& requestArgument = requestArguments[requestArgumentIndex];
1180         // Get the operand index for this argument. We extract it from the list
1181         // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
1182         // We assume in this function that the model has been validated already.
1183         const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
1184         const Operand& operand = operands[operandIndex];
1185         if (requestArgument.lifetime != Request::Argument::LifeTime::NO_VALUE) {
1186             const bool isExtensionType = isExtension(operand.type);
1187             // If the argument specified a dimension, validate it.
1188             uint32_t modelRank = operand.dimensions.size();
1189             uint32_t requestRank = requestArgument.dimensions.size();
1190             if (requestRank == 0) {
1191                 // NOTE: validateRequestArguments cannot validate unknown tensor rank with
1192                 // extension operand type.
1193                 if (!isExtensionType && !isNonExtensionScalar(operand.type)) {
1194                     if (modelRank <= 0) {
1195                         NN_VALIDATE(isOutput)
1196                                 << "Model has unknown input rank but the request does not "
1197                                    "specify the rank.";
1198                         NN_VALIDATE(allowUnspecifiedOutput)
1199                                 << "Model has unknown output rank and request does not specify it.";
1200                         // Unspecified output dimensions introduced in Android Q.
1201                         version = combineVersions(version, Version::ANDROID_Q);
1202                     }
1203                 }
1204                 // Validate that all the dimensions are specified in the model.
1205                 for (size_t i = 0; i < modelRank; i++) {
1206                     if (operand.dimensions[i] == 0) {
1207                         NN_VALIDATE(isOutput && allowUnspecifiedOutput)
1208                                 << "Model has dimension " << i
1209                                 << " set to 0 but the request does not specify the dimension.";
1210                         // Unspecified output dimensions introduced in Android Q.
1211                         version = combineVersions(version, Version::ANDROID_Q);
1212                     }
1213                 }
1214             } else {
1215                 NN_VALIDATE(modelRank == 0 || requestRank == modelRank)
1216                         << "Request " << type << " " << requestArgumentIndex
1217                         << " has number of dimensions (" << requestRank
1218                         << ") different than the model's (" << modelRank << ")";
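                // For example, request dimensions {4, 3} are compatible with model dimensions
                // {4, 0} (a model dimension of 0 means unspecified), but not with model
                // dimensions {4, 2}.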
1219                 for (size_t i = 0; i < requestRank; i++) {
1220                     NN_VALIDATE(modelRank == 0 || operand.dimensions[i] == 0 ||
1221                                 requestArgument.dimensions[i] == operand.dimensions[i])
1222                             << "Request " << type << " " << requestArgumentIndex
1223                             << " has dimension " << i << " of " << requestArgument.dimensions[i]
1224                             << " different than the model's " << operand.dimensions[i];
1225                     if (requestArgument.dimensions[i] == 0) {
1226                         NN_VALIDATE(isOutput && allowUnspecifiedOutput)
1227                                 << "Request " << type << " " << requestArgumentIndex
1228                                 << " has dimension " << i << " of zero";
1229                         // Unspecified output dimensions introduced in Android Q.
1230                         version = combineVersions(version, Version::ANDROID_Q);
1231                     }
1232                 }
1233             }
1234             // NOTE: validateRequestArguments cannot validate DataLocation::length
1235             // with extension operand type.
1236             if (!isExtensionType && requestArgument.location.length != 0) {
1237                 const auto dimensions =
1238                         NN_TRY(combineDimensions(operand.dimensions, requestArgument.dimensions));
1239                 const size_t expectedLength = getNonExtensionSize(operand.type, dimensions).value();
1240                 if (expectedLength != 0) {
1241                     NN_VALIDATE_EQ(requestArgument.location.length, expectedLength)
1242                             << "Request " << type << " " << requestArgumentIndex
1243                             << " expected a size of " << expectedLength << " but got "
1244                             << requestArgument.location.length;
1245                 }
1246             }
1247         }
1248     }
1249     return version;
1250 }
1251 
1252 Result<Version> validateRequestForModelImpl(const Request& request, const Model& model,
1253                                             bool allowUnspecifiedOutput) {
1254     auto version = NN_TRY(validateRequest(request));
1255     version = combineVersions(version, NN_TRY(validateModel(model)));
1256     version = combineVersions(version,
1257                               NN_TRY(validateRequestArgumentsForModel(
1258                                       request.inputs, model.main.inputIndexes, model.main.operands,
1259                                       /*isOutput=*/false, /*allowUnspecifiedOutput=*/true)));
1260     version = combineVersions(
1261             version, NN_TRY(validateRequestArgumentsForModel(
1262                              request.outputs, model.main.outputIndexes, model.main.operands,
1263                              /*isOutput=*/true, allowUnspecifiedOutput)));
1264     return version;
1265 }
1266 
1267 Result<Version> validateMemoryDescImpl(
1268         const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
1269         const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
1270         const std::function<const Model*(const SharedPreparedModel&)>& getModel,
1271         std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
1272     NN_VALIDATE(!preparedModels.empty());
1273     NN_VALIDATE(!inputRoles.empty() || !outputRoles.empty());
1274 
1275     std::set<PreparedModelRole> roles;
1276     std::vector<nn::Operand> operands;
1277     operands.reserve(inputRoles.size() + outputRoles.size());
1278     for (const auto& role : inputRoles) {
1279         NN_VALIDATE_LT(role.modelIndex, preparedModels.size());
1280         const auto& preparedModel = preparedModels[role.modelIndex];
1281         NN_VALIDATE(preparedModel != nullptr);
1282         const auto* model = getModel(preparedModel);
1283         NN_VALIDATE(model != nullptr);
1284         const auto& inputIndexes = model->main.inputIndexes;
1285         NN_VALIDATE_LT(role.ioIndex, inputIndexes.size());
1286         NN_VALIDATE_GT(role.probability, 0.0f);
1287         NN_VALIDATE_LE(role.probability, 1.0f);
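        // std::set::emplace returns success == false if an equivalent (model, IOType, index)
        // entry is already present, so the check below rejects duplicate roles for the same
        // model input.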
1288         const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
1289         NN_VALIDATE(success);
1290         operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
1291     }
1292     for (const auto& role : outputRoles) {
1293         NN_VALIDATE_LT(role.modelIndex, preparedModels.size());
1294         const auto& preparedModel = preparedModels[role.modelIndex];
1295         NN_VALIDATE(preparedModel != nullptr);
1296         const auto* model = getModel(preparedModel);
1297         NN_VALIDATE(model != nullptr);
1298         const auto& outputIndexes = model->main.outputIndexes;
1299         NN_VALIDATE_LT(role.ioIndex, outputIndexes.size());
1300         NN_VALIDATE_GT(role.probability, 0.0f);
1301         NN_VALIDATE_LE(role.probability, 1.0f);
1302         const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
1303         NN_VALIDATE(success);
1304         operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
1305     }
1306 
1307     CHECK(!operands.empty());
1308     const auto opType = operands.front().type;
1309 
1310     Dimensions dimensions = desc.dimensions;
1311     for (const auto& operand : operands) {
1312         NN_VALIDATE_EQ(operand.type, opType) << operand.type << " vs " << operands.front().type;
1313         NN_VALIDATE_EQ(operand.scale, operands.front().scale);
1314         NN_VALIDATE_EQ(operand.zeroPoint, operands.front().zeroPoint);
1315         // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
1316         if (!isExtension(opType)) {
1317             NN_VALIDATE_EQ(operand.extraParams, operands.front().extraParams)
1318                     << operand.extraParams << " vs " << operands.front().extraParams;
1319         }
1320         dimensions = NN_TRY(combineDimensions(dimensions, operand.dimensions));
1321     }
1322 
1323     // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
1324     if (!isExtension(opType)) {
1325         NN_VALIDATE(!isNonExtensionScalar(opType) || dimensions.empty())
1326                 << "invalid dimensions with scalar operand type.";
1327     }
1328 
1329     if (preparedModelRoles != nullptr) {
1330         *preparedModelRoles = std::move(roles);
1331     }
1332     if (combinedOperand != nullptr) {
1333         *combinedOperand = operands.front();
1334         combinedOperand->dimensions = dimensions;
1335     }
1336     return Version::ANDROID_R;
1337 }
1338 
1339 class OperationValidationContext : public IOperationValidationContext {
1340     DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);
1341 
1342    public:
1343     OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes,
1344                                const std::vector<uint32_t>& outputIndexes,
1345                                const std::vector<Operand>& operands)
1346         : operationName(operationName),
1347           inputIndexes(inputIndexes),
1348           outputIndexes(outputIndexes),
1349           operands(operands) {}
1350 
1351     const char* getOperationName() const override;
1352 
1353     uint32_t getNumInputs() const override;
1354     OperandType getInputType(uint32_t index) const override;
1355     Shape getInputShape(uint32_t index) const override;
1356     const Operand::ExtraParams& getInputExtraParams(uint32_t index) const override;
1357 
1358     uint32_t getNumOutputs() const override;
1359     OperandType getOutputType(uint32_t index) const override;
1360     Shape getOutputShape(uint32_t index) const override;
1361 
1362    private:
1363     const Operand* getInputOperand(uint32_t index) const;
1364     const Operand* getOutputOperand(uint32_t index) const;
1365 
1366     const char* operationName;
1367     const std::vector<uint32_t>& inputIndexes;
1368     const std::vector<uint32_t>& outputIndexes;
1369     const std::vector<Operand>& operands;
1370 };
1371 
1372 const char* OperationValidationContext::getOperationName() const {
1373     return operationName;
1374 }
1375 
1376 const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
1377     return &operands.at(inputIndexes.at(index));
1378 }
1379 
1380 const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const {
1381     return &operands.at(outputIndexes.at(index));
1382 }
1383 
1384 uint32_t OperationValidationContext::getNumInputs() const {
1385     auto count = inputIndexes.size();
1386     CHECK_LE(count, std::numeric_limits<uint32_t>::max());
1387     return static_cast<uint32_t>(count);
1388 }
1389 
1390 uint32_t OperationValidationContext::getNumOutputs() const {
1391     auto count = outputIndexes.size();
1392     CHECK_LE(count, std::numeric_limits<uint32_t>::max());
1393     return static_cast<uint32_t>(count);
1394 }
1395 
1396 OperandType OperationValidationContext::getInputType(uint32_t index) const {
1397     return getInputOperand(index)->type;
1398 }
1399 
1400 Shape OperationValidationContext::getInputShape(uint32_t index) const {
1401     const Operand* operand = getInputOperand(index);
1402     return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
1403             operand->extraParams};
1404 }
1405 
1406 const Operand::ExtraParams& OperationValidationContext::getInputExtraParams(uint32_t index) const {
1407     return getInputOperand(index)->extraParams;
1408 }
1409 
1410 OperandType OperationValidationContext::getOutputType(uint32_t index) const {
1411     return getOutputOperand(index)->type;
1412 }
1413 
1414 Shape OperationValidationContext::getOutputShape(uint32_t index) const {
1415     const Operand* operand = getOutputOperand(index);
1416     return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
1417             operand->extraParams};
1418 }
1419 
1420 // TODO(b/169345292): reduce the duplicate validation here
1421 
1422 Result<void> validateOperandSymmPerChannelQuantParamsImpl(
1423         const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
1424         const char* tag) {
1425     if (operand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
1426         NN_VALIDATE_FAIL();
1427     }
1428 
1429     NN_VALIDATE_LT(channelQuant.channelDim, operand.dimensions.size()) << tag;
1430     NN_VALIDATE(!channelQuant.scales.empty()) << tag;
1431     NN_VALIDATE_EQ(channelQuant.scales.size(), operand.dimensions[channelQuant.channelDim]) << tag;
1432     NN_VALIDATE_NE(operand.dimensions[channelQuant.channelDim], 0u)
1433             << tag << " channel dimension " << channelQuant.channelDim << " is underspecified";
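    // For example, an operand with dimensions {3, 2, 2, 2} and channelDim == 0 must provide
    // exactly 3 scales, each of which must be positive (checked below).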
1434     for (uint32_t i = 0; i < operand.dimensions[channelQuant.channelDim]; i++) {
1435         NN_VALIDATE_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]";
1436     }
1437     return {};
1438 }
1439 
1440 Result<void> validateScalarDimensions(const Operand& type, const char* tag) {
1441     NN_VALIDATE(type.dimensions.empty()) << tag << " invalid dimensions for scalar type";
1442     return {};
1443 }
1444 
1445 Result<void> validateQuant8AsymmParams(const Operand& type, const char* tag) {
1446     NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 255)
1447             << tag << " invalid zeroPoint: " << type.zeroPoint;
1448     NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1449     return {};
1450 }
1451 
1452 Result<void> validateQuant8AsymmSignedParams(const Operand& type, const char* tag) {
1453     NN_VALIDATE(-128 <= type.zeroPoint && type.zeroPoint <= 127)
1454             << tag << " invalid zeroPoint: " << type.zeroPoint;
1455     NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1456     return {};
1457 }
1458 
1459 Result<void> validateQuant8SymmParams(const Operand& type, const char* tag) {
1460     NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint;
1461     NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1462     return {};
1463 }
1464 
1465 Result<void> validateQuant16AsymmParams(const Operand& type, const char* tag) {
1466     NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 65535)
1467             << tag << " invalid zeroPoint: " << type.zeroPoint;
1468     NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1469     return {};
1470 }
1471 
1472 Result<void> validateQuantSymmParams(const Operand& type, const char* tag) {
1473     NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
1474     NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1475     return {};
1476 }
1477 
1478 Result<void> validateNoQuantParams(const Operand& type, const char* tag) {
1479     NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
1480     NN_VALIDATE_EQ(type.scale, 0.0f) << tag << " scale is not zero";
1481     return {};
1482 }
1483 
1484 Result<void> validateTensorDimensions(
1485         const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo,
1486         const char* tag, bool allowPartial) {
1487     if (!allowPartial) {
1488         NN_VALIDATE(!type.dimensions.empty()) << tag << " invalid operand dimensions";
1489     }
1490     uint64_t size = isExtension(type.type) ? extensionOperandTypeInfo->byteSize
1491                                            : getNonExtensionSize(type.type);
1492     constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max();
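    // For example, a TENSOR_FLOAT32 operand (4 bytes per element) with dimensions
    // {65536, 65536} would require 2^34 bytes, which exceeds kMaxSize and is rejected below.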
1493     for (size_t i = 0; i < type.dimensions.size(); i++) {
1494         if (!allowPartial) {
1495             NN_VALIDATE_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions";
1496         }
1497         if (type.dimensions[i] != 0) {
1498             size *= type.dimensions[i];
1499             NN_VALIDATE_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize;
1500         }
1501     }
1502     return {};
1503 }
1504 
1505 Result<void> validateOperandTypeImpl(
1506         const Operand& type,
1507         const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
1508         bool allowPartial) {
1509     if (isExtension(type.type)) {
1510         NN_VALIDATE(extensionOperandTypeInfo != nullptr);
1511         if (extensionOperandTypeInfo->isTensor) {
1512             NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
1513         } else {
1514             NN_TRY(validateScalarDimensions(type, tag));
1515         }
1516         return validateNoQuantParams(type, tag);
1517     }
1518 
1519     NN_VALIDATE(extensionOperandTypeInfo == nullptr);
1520     NN_TRY(validateOperandType(type.type));
1521 
1522     if (isNonExtensionScalar(type.type)) {
1523         NN_TRY(validateScalarDimensions(type, tag));
1524         if (type.type != OperandType::OEM) {  // Historically, we have allowed OEM types
1525                                               // to use quantization parameters.
1526             NN_TRY(validateNoQuantParams(type, tag));
1527         }
1528     } else {
1529         NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
1530         if (type.type == OperandType::TENSOR_QUANT8_ASYMM) {
1531             NN_TRY(validateQuant8AsymmParams(type, tag));
1532         } else if (type.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1533             NN_TRY(validateQuant8AsymmSignedParams(type, tag));
1534         } else if (type.type == OperandType::TENSOR_QUANT8_SYMM) {
1535             NN_TRY(validateQuant8SymmParams(type, tag));
1536         } else if (type.type == OperandType::TENSOR_QUANT16_ASYMM) {
1537             NN_TRY(validateQuant16AsymmParams(type, tag));
1538         } else if (type.type == OperandType::TENSOR_QUANT16_SYMM) {
1539             NN_TRY(validateQuantSymmParams(type, tag));
1540         } else if (type.type == OperandType::TENSOR_INT32 ||
1541                    type.type == OperandType::TENSOR_OEM_BYTE) {
1542             // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters.
1543             // Historically, we have allowed OEM types to use quantization parameters.
1544         } else {
1545             NN_TRY(validateNoQuantParams(type, tag));
1546         }
1547     }
1548 
1549     return {};
1550 }
1551 
1552 Result<void> validateOperandListImpl(const std::vector<uint32_t>& list, size_t operandCount,
1553                                      const char* tag) {
1554     for (size_t i = 0; i < list.size(); i++) {
1555         NN_VALIDATE_LT(list[i], operandCount) << tag << " invalid operand index at " << i << " = "
1556                                               << list[i] << ", operandCount " << operandCount;
1557     }
1558     return {};
1559 }
1560 
1561 Result<void> validateOperationOperandTypes(const std::vector<Operand>& operands,
1562                                            const std::vector<uint32_t>& inputIndexes,
1563                                            const std::vector<OperandType>& inExpectedTypes,
1564                                            const std::vector<uint32_t>& outputIndexes,
1565                                            const std::vector<OperandType>& outExpectedInTypes) {
1566     NN_VALIDATE_EQ(inputIndexes.size(), inExpectedTypes.size())
1567             << "Wrong operand count: expected " << inExpectedTypes.size() << " inputs, got "
1568             << inputIndexes.size() << " inputs";
1569     NN_VALIDATE_EQ(outputIndexes.size(), outExpectedInTypes.size())
1570             << "Wrong operand count: expected " << outExpectedInTypes.size() << " outputs, got "
1571             << outputIndexes.size() << " outputs";
1572     for (size_t i = 0; i < inputIndexes.size(); i++) {
1573         NN_VALIDATE_EQ(operands[inputIndexes[i]].type, inExpectedTypes[i])
1574                 << "Invalid input tensor type " << operands[inputIndexes[i]].type << " for input "
1575                 << i << ", expected " << inExpectedTypes[i];
1576     }
1577     for (size_t i = 0; i < outputIndexes.size(); i++) {
1578         NN_VALIDATE_EQ(operands[outputIndexes[i]].type, outExpectedInTypes[i])
1579                 << "Invalid output tensor type " << operands[outputIndexes[i]].type << " for output "
1580                 << i << ", expected " << outExpectedInTypes[i];
1581     }
1582 
1583     return {};
1584 }
1585 
1586 Result<void> validateSubgraphReference(const std::vector<Model::Subgraph>& subgraphs,
1587                                        const Operand& modelOperand) {
1588     NN_VALIDATE_EQ(modelOperand.type, OperandType::SUBGRAPH)
1589             << "Unexpected operand type: " << modelOperand.type;
1590     NN_VALIDATE_LT(modelOperand.location.offset, subgraphs.size()) << "Invalid subgraph reference";
1591     return {};
1592 }
1593 const Model::Subgraph& getSubgraph(const std::vector<Model::Subgraph>& subgraphs,
1594                                    const Operand& modelOperand) {
1595     return subgraphs.at(modelOperand.location.offset);
1596 }
1597 uint32_t getInputCount(const std::vector<Model::Subgraph>& subgraphs, const Operand& modelOperand) {
1598     return getSubgraph(subgraphs, modelOperand).inputIndexes.size();
1599 }
1600 uint32_t getOutputCount(const std::vector<Model::Subgraph>& subgraphs,
1601                         const Operand& modelOperand) {
1602     return getSubgraph(subgraphs, modelOperand).outputIndexes.size();
1603 }
1604 const Operand& getInputOperand(const std::vector<Model::Subgraph>& subgraphs,
1605                                const Operand& modelOperand, uint32_t index) {
1606     const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
1607     return subgraph.operands.at(subgraph.inputIndexes.at(index));
1608 }
1609 const Operand& getOutputOperand(const std::vector<Model::Subgraph>& subgraphs,
1610                                 const Operand& modelOperand, uint32_t index) {
1611     const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
1612     return subgraph.operands.at(subgraph.outputIndexes.at(index));
1613 }
1614 
1615 // Checks if two operands have the same types, ranks (if specified), dimensions
1616 // (if specified), scales, zeroPoints, and extraParams.
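// For example, dimensions {2, 0} and {2, 3} are considered compatible because a dimension of 0
// means "unspecified", whereas {2, 3} and {2, 4} are not.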
1617 Result<void> compatible(const Operand& a, const Operand& b) {
1618     NN_VALIDATE_EQ(a.type, b.type) << a.type << " != " << b.type;
1619     if (!a.dimensions.empty() && !b.dimensions.empty()) {
1620         NN_VALIDATE_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions";
1621         for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) {
1622             if (a.dimensions[i] != 0 && b.dimensions[i] != 0) {
1623                 NN_VALIDATE_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions";
1624             }
1625         }
1626     }
1627     NN_VALIDATE_EQ(a.scale, b.scale);
1628     NN_VALIDATE_EQ(a.zeroPoint, b.zeroPoint);
1629     NN_VALIDATE_EQ(a.extraParams, b.extraParams) << a.extraParams << " != " << b.extraParams;
1630     return {};
1631 }
1632 
1633 Result<void> validateConditionOperand(const Operand& operand) {
1634     NN_VALIDATE_EQ(operand.type, OperandType::TENSOR_BOOL8)
1635             << "Unexpected condition operand type: " << operand.type;
1636     NN_VALIDATE_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
1637     NN_VALIDATE_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
1638     return {};
1639 }
1640 
1641 Result<Version> validateIfOperation(const std::vector<uint32_t>& inputs,
1642                                     const std::vector<uint32_t>& outputs,
1643                                     const std::vector<Operand>& operands,
1644                                     const std::vector<Model::Subgraph>& subgraphs) {
1645     namespace op = operation_if;
1646     NN_VALIDATE_GE(inputs.size(), 3u) << "IF must have at least 3 inputs";
1647     NN_VALIDATE_GE(outputs.size(), 1u) << "IF must have at least 1 output";
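    // For example, if each branch model takes 2 inputs and produces 1 output, the IF operation
    // has 3 + 2 = 5 inputs (the condition and the two branch model references, followed by the
    // branch inputs) and 1 output.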
1648     auto validateBranchOperand = [&](const Operand& branchModelOperand) -> Result<void> {
1649         auto result = validateSubgraphReference(subgraphs, branchModelOperand);
1650         if (!result.has_value()) {
1651             return error() << std::move(result).error()
1652                            << " -- Operand is not a valid subgraph reference";
1653         }
1654         const uint32_t branchModelInputCount = getInputCount(subgraphs, branchModelOperand);
1655         const uint32_t branchModelOutputCount = getOutputCount(subgraphs, branchModelOperand);
1656         NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + branchModelInputCount);
1657         NN_VALIDATE_EQ(outputs.size(), branchModelOutputCount);
1658         for (uint32_t i = 0; i < branchModelInputCount; ++i) {
1659             const Operand& innerOperand = getInputOperand(subgraphs, branchModelOperand, i);
1660             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1661             NN_TRY(compatible(innerOperand, outerOperand));
1662         }
1663         for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
1664             const Operand& innerOperand = getOutputOperand(subgraphs, branchModelOperand, i);
1665             const Operand& outerOperand = operands[outputs[i]];
1666             NN_TRY(compatible(innerOperand, outerOperand));
1667         }
1668         return {};
1669     };
1670     auto result = validateConditionOperand(operands[inputs[op::kCondBoolOperand]]);
1671     if (!result.has_value()) {
1672         return error() << std::move(result).error() << " for IF condition operand";
1673     }
1674     result = validateBranchOperand(operands[inputs[op::kThenModelOperand]]);
1675     if (!result.has_value()) {
1676         return error() << std::move(result).error() << " for IF then model";
1677     }
1678     result = validateBranchOperand(operands[inputs[op::kElseModelOperand]]);
1679     if (!result.has_value()) {
1680         return error() << std::move(result).error() << " for IF else model";
1681     }
1682     return Version::ANDROID_R;
1683 }
1684 
1685 Result<Version> validateControlFlowOperandUnknownSize(const Operand& operand) {
1686     if (!isExtension(operand.type) && getNonExtensionSize(operand).value() == 0) {
1687         // 1.3 HAL (corresponding to Version::ANDROID_R) does not support CF operations with
1688         // operands of unknown size. See http://b/132458982#comment63.
1689         return Version::CURRENT_RUNTIME;
1690     }
1691     return Version::ANDROID_R;
1692 }
1693 
1694 Result<Version> validateWhileOperation(const std::vector<uint32_t>& inputs,
1695                                        const std::vector<uint32_t>& outputs,
1696                                        const std::vector<Operand>& operands,
1697                                        const std::vector<Model::Subgraph>& subgraphs) {
1698     // Let the loop have
1699     // - m >= 1 input-output operands,
1700     // - k >= 0 state-only operands, and
1701     // - n >= 0 input-only operands.
1702     // Then
1703     // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
1704     // - the condition model has (m + k + n) inputs and 1 output.
1705     // - the body model has (m + k + n) inputs and (m + k) outputs.
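    // For example, with m = 2, k = 1, and n = 1, the WHILE operation has 2 + 2 + 1 + 1 = 6 inputs
    // and 2 outputs, the condition model has 4 inputs and 1 output, and the body model has
    // 4 inputs and 3 outputs.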
1706     namespace op = operation_while;
1707     NN_VALIDATE_GE(inputs.size(), 3u) << "WHILE must have at least 3 inputs";
1708     NN_VALIDATE_GE(outputs.size(), 1u) << "WHILE must have at least 1 output";
1709     auto validateCondOperand = [&](const Operand& condModelOperand) -> Result<Version> {
1710         Version version = Version::ANDROID_R;
1711         auto result = validateSubgraphReference(subgraphs, condModelOperand);
1712         if (!result.has_value()) {
1713             return error() << std::move(result).error()
1714                            << " -- Operand is not a valid subgraph reference";
1715         }
1716         const uint32_t condModelInputCount = getInputCount(subgraphs, condModelOperand);
1717         const uint32_t condModelOutputCount = getOutputCount(subgraphs, condModelOperand);
1718         NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + condModelInputCount);
1719         NN_VALIDATE_EQ(condModelOutputCount, 1u);
1720         for (uint32_t i = 0; i < condModelInputCount; ++i) {
1721             const Operand& innerOperand = getInputOperand(subgraphs, condModelOperand, i);
1722             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1723             NN_TRY(compatible(innerOperand, outerOperand));
1724             version = combineVersions(version,
1725                                       NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
1726             version = combineVersions(version,
1727                                       NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1728         }
1729         NN_TRY(validateConditionOperand(getOutputOperand(subgraphs, condModelOperand, 0)));
1730         return version;
1731     };
1732     auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> Result<Version> {
1733         Version version = Version::ANDROID_R;
1734         auto result = validateSubgraphReference(subgraphs, bodyModelOperand);
1735         if (!result.has_value()) {
1736             return error() << std::move(result).error()
1737                            << " -- Operand is not a valid subgraph reference";
1738         }
1739         const uint32_t bodyModelInputCount = getInputCount(subgraphs, bodyModelOperand);
1740         const uint32_t bodyModelOutputCount = getOutputCount(subgraphs, bodyModelOperand);
1741         NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + bodyModelInputCount);
1742         NN_VALIDATE_GE(bodyModelOutputCount, outputs.size());
1743         NN_VALIDATE_GE(bodyModelInputCount, bodyModelOutputCount);
1744         const uint32_t inputOutputCount = outputs.size();
1745         const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
1746         const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
1747         for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
1748             const Operand& innerOperand = getInputOperand(subgraphs, bodyModelOperand, i);
1749             const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1750             NN_TRY(compatible(innerOperand, outerOperand));
1751             version = combineVersions(version,
1752                                       NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
1753             version = combineVersions(version,
1754                                       NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1755         }
1756         for (uint32_t i = 0; i < inputOutputCount; ++i) {
1757             const Operand& innerOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
1758             const Operand& outerOperand = operands[outputs[i]];
1759             NN_TRY(compatible(innerOperand, outerOperand));
1760             version = combineVersions(version,
1761                                       NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1762         }
1763         for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
1764             const Operand& inputOperand = getInputOperand(subgraphs, bodyModelOperand, i);
1765             const Operand& outputOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
1766             NN_TRY(compatible(inputOperand, outputOperand));
1767             version = combineVersions(version,
1768                                       NN_TRY(validateControlFlowOperandUnknownSize(outputOperand)));
1769         }
1770         return version;
1771     };
1772     auto result = validateCondOperand(operands[inputs[op::kCondModelOperand]]);
1773     if (!result.has_value()) {
1774         return error() << std::move(result).error() << " for WHILE condition model";
1775     }
1776     auto version = result.value();
1777     result = validateBodyOperand(operands[inputs[op::kBodyModelOperand]]);
1778     if (!result.has_value()) {
1779         return error() << std::move(result).error() << " for WHILE body model";
1780     }
1781     version = combineVersions(version, result.value());
1782     return version;
1783 }
1784 
1785 Result<Version> validateOperationButNotOperandsImpl(const Operation& operation,
1786                                                     const std::vector<Operand>& operands,
1787                                                     const std::vector<Model::Subgraph>& subgraphs) {
1788     const auto opType = operation.type;
1789     const auto& inputIndexes = operation.inputs;
1790     const auto& outputIndexes = operation.outputs;
1791 
1792     NN_TRY(validateOperandListImpl(inputIndexes, operands.size(),
1793                                    "ANeuralNetworksModel_addOperation inputs"));
1794     NN_TRY(validateOperandListImpl(outputIndexes, operands.size(),
1795                                    "ANeuralNetworksModel_addOperation outputs"));
1796 
1797     if (isExtension(opType)) {
1798         // There is no other validation we can do for an extension operation.
1799         return Version::ANDROID_Q;
1800     }
1801 
1802     auto invalidInOutNumberMessage = [opType, &inputIndexes, &outputIndexes](int expIn,
1803                                                                              int expOut) {
1804         std::ostringstream os;
1805         os << "Invalid number of input operands (" << inputIndexes.size() << ", expected " << expIn
1806            << ") or output operands (" << outputIndexes.size() << ", expected " << expOut
1807            << ") for operation " << opType;
1808         return os.str();
1809     };
1810 
1811     switch (opType) {
1812         case OperationType::OEM_OPERATION: {
1813             return Version::ANDROID_OC_MR1;
1814         }
1815         case OperationType::RESHAPE: {
1816             NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
1817                     << invalidInOutNumberMessage(2, 1);
1818             auto inputType = operands[inputIndexes[0]].type;
1819             Version version;
1820             std::vector<OperandType> inExpectedTypes;
1821             std::vector<OperandType> outExpectedTypes;
1822             if (inputType == OperandType::TENSOR_FLOAT32) {
1823                 version = Version::ANDROID_OC_MR1;
1824                 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32};
1825                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1826             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1827                 version = Version::ANDROID_Q;
1828                 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32};
1829                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1830             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1831                 version = Version::ANDROID_OC_MR1;
1832                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32};
1833                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1834             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1835                 version = Version::ANDROID_R;
1836                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
1837                                    OperandType::TENSOR_INT32};
1838                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1839             } else {
1840                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
1841             }
1842             const auto inputRank = operands[inputIndexes[0]].dimensions.size();
1843             NN_VALIDATE_LE(inputRank, 4u)
1844                     << "Unsupported input tensor rank for operation " << opType;
1845             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1846                                                  outputIndexes, outExpectedTypes));
1847             return version;
1848         }
1849         case OperationType::DEPTH_TO_SPACE: {
1850             NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
1851                         outputIndexes.size() == 1)
1852                     << "Invalid number of input operands (" << inputIndexes.size()
1853                     << ", expected 3 or 2) or output operands (" << outputIndexes.size()
1854                     << ", expected 1) for operation " << opType;
1855             auto inputType = operands[inputIndexes[0]].type;
1856             Version version;
1857             std::vector<OperandType> inExpectedTypes;
1858             std::vector<OperandType> outExpectedTypes;
1859             if (inputType == OperandType::TENSOR_FLOAT32) {
1860                 version = Version::ANDROID_OC_MR1;
1861                 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
1862                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1863             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1864                 version = Version::ANDROID_Q;
1865                 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
1866                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1867             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1868                 version = Version::ANDROID_OC_MR1;
1869                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
1870                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1871             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1872                 version = Version::ANDROID_R;
1873                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
1874                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1875             } else {
1876                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
1877             }
1878             if (inputIndexes.size() == 3) {
1879                 inExpectedTypes.push_back(OperandType::BOOL);
1880                 version = combineVersions(version, Version::ANDROID_Q);
1881             } else {
1882                 version = combineVersions(version, Version::ANDROID_OC_MR1);
1883             }
1884             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1885                                                  outputIndexes, outExpectedTypes));
1886             return version;
1887         }
1888         case OperationType::SPACE_TO_DEPTH: {
1889             NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
1890                         outputIndexes.size() == 1)
1891                     << "Invalid number of input operands (" << inputIndexes.size()
1892                     << ", expected 3 or 2) or output operands (" << outputIndexes.size()
1893                     << ", expected 1) for operation " << opType;
1894             auto inputType = operands[inputIndexes[0]].type;
1895             Version version;
1896             std::vector<OperandType> inExpectedTypes;
1897             std::vector<OperandType> outExpectedTypes;
1898             if (inputType == OperandType::TENSOR_FLOAT32) {
1899                 version = Version::ANDROID_OC_MR1;
1900                 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
1901                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1902             } else if (inputType == OperandType::TENSOR_FLOAT16) {
1903                 version = Version::ANDROID_Q;
1904                 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
1905                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1906             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1907                 version = Version::ANDROID_OC_MR1;
1908                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
1909                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1910             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1911                 version = Version::ANDROID_R;
1912                 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
1913                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1914             } else {
1915                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
1916             }
1917             if (inputIndexes.size() == 3) {
1918                 inExpectedTypes.push_back(OperandType::BOOL);
1919                 version = combineVersions(version, Version::ANDROID_Q);
1920             } else {
1921                 version = combineVersions(version, Version::ANDROID_OC_MR1);
1922             }
1923             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1924                                                  outputIndexes, outExpectedTypes));
1925             return version;
1926         }
1927         case OperationType::EMBEDDING_LOOKUP: {
1928             NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
1929                     << invalidInOutNumberMessage(2, 1);
1930             auto inputType = operands[inputIndexes[1]].type;
1931             NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
1932                         inputType == OperandType::TENSOR_FLOAT32 ||
1933                         inputType == OperandType::TENSOR_INT32 ||
1934                         inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1935                         inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
1936                     << "Unsupported input tensor type for operation " << opType;
1937             Version version;
1938             std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType};
1939             std::vector<OperandType> outExpectedTypes = {inputType};
1940             if (inputType == OperandType::TENSOR_FLOAT16 ||
1941                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1942                 version = Version::ANDROID_R;
1943             } else if (inputType == OperandType::TENSOR_INT32 ||
1944                        inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1945                 version = Version::ANDROID_Q;
1946             } else {
1947                 version = Version::ANDROID_OC_MR1;
1948             }
1949             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1950                                                  outputIndexes, outExpectedTypes));
1951             return version;
1952         }
1953         case OperationType::HASHTABLE_LOOKUP: {
1954             NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 2)
1955                     << invalidInOutNumberMessage(3, 2);
1956             auto inputType = operands[inputIndexes[2]].type;
1957             NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
1958                         inputType == OperandType::TENSOR_INT32 ||
1959                         inputType == OperandType::TENSOR_QUANT8_ASYMM)
1960                     << "Unsupported input tensor type for operation " << opType;
1961             std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32,
1962                                                         OperandType::TENSOR_INT32, inputType};
1963             std::vector<OperandType> outExpectedTypes = {inputType,
1964                                                          OperandType::TENSOR_QUANT8_ASYMM};
1965             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1966                                                  outputIndexes, outExpectedTypes));
1967             return Version::ANDROID_OC_MR1;
1968         }
1969         case OperationType::LSH_PROJECTION: {
1970             NN_VALIDATE(inputIndexes.size() == 4 && outputIndexes.size() == 1)
1971                     << invalidInOutNumberMessage(4, 1);
1972             auto inputType = operands[inputIndexes[1]].type;
1973             NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
1974                         inputType == OperandType::TENSOR_FLOAT32 ||
1975                         inputType == OperandType::TENSOR_INT32 ||
1976                         inputType == OperandType::TENSOR_QUANT8_ASYMM)
1977                     << "Unsupported input tensor type for operation " << opType;
1978             auto hashType = operands[inputIndexes[0]].type;
1979             Version version;
1980             std::vector<OperandType> inExpectedTypes;
1981             if (hashType == OperandType::TENSOR_FLOAT16) {
1982                 version = Version::ANDROID_Q;
1983                 inExpectedTypes = {
1984                         OperandType::TENSOR_FLOAT16,
1985                         inputType,
1986                         OperandType::TENSOR_FLOAT16,
1987                         OperandType::INT32,
1988                 };
1989             } else if (hashType == OperandType::TENSOR_FLOAT32) {
1990                 version = Version::ANDROID_OC_MR1;
1991                 inExpectedTypes = {
1992                         OperandType::TENSOR_FLOAT32,
1993                         inputType,
1994                         OperandType::TENSOR_FLOAT32,
1995                         OperandType::INT32,
1996                 };
1997             } else {
1998                 NN_VALIDATE_FAIL() << "Unsupported hash tensor type for operation " << opType;
1999             }
2000             std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
2001             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2002                                                  outputIndexes, outExpectedTypes));
2003             return version;
2004         }
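             // BIDIRECTIONAL_SEQUENCE_LSTM: the 61 inputs are 48 weight/bias/state tensors, an
             // INT32 activation, two clipping scalars (FLOAT32 or FLOAT16, matching the input
             // tensor type), two BOOL flags, and 8 layer-normalization tensors. Variants that
             // also produce state outputs (5 or 6 outputs) require Android R.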
2005         case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM: {
2006             const uint32_t kNumOutputs = 2;
2007             const uint32_t kNumOutputsMerged = 1;
2008             const uint32_t kNumOutputsWithState = 6;
2009             const uint32_t kNumOutputsMergedWithState = 5;
2010             NN_VALIDATE(inputIndexes.size() == 61 &&
2011                         (outputIndexes.size() == kNumOutputs ||
2012                          outputIndexes.size() == kNumOutputsMerged ||
2013                          outputIndexes.size() == kNumOutputsWithState ||
2014                          outputIndexes.size() == kNumOutputsMergedWithState))
2015                     << "Invalid number of input operands (" << inputIndexes.size()
2016                     << ", expected 61) or output operands (" << outputIndexes.size()
2017                     << ", expected 1, 2, 5 or 6) for operation " << opType;
2018 
2019             std::vector<OperandType> inExpectedTypes;
2020             auto inputType = operands[inputIndexes[0]].type;
2021             NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
2022                         inputType == OperandType::TENSOR_FLOAT16)
2023                     << "Unsupported input tensor type for operation " << opType;
2024 
2025             inExpectedTypes = {};
2026             for (int i = 0; i < 48; ++i) {
2027                 inExpectedTypes.push_back(inputType);
2028             }
2029             inExpectedTypes.push_back(OperandType::INT32);
2030             inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
2031                                               ? OperandType::FLOAT32
2032                                               : OperandType::FLOAT16);
2033             inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
2034                                               ? OperandType::FLOAT32
2035                                               : OperandType::FLOAT16);
2036             inExpectedTypes.push_back(OperandType::BOOL);
2037             inExpectedTypes.push_back(OperandType::BOOL);
2038             for (int i = 0; i < 8; ++i) {
2039                 inExpectedTypes.push_back(inputType);
2040             }
2041 
2042             Version version = Version::ANDROID_Q;
2043             if (outputIndexes.size() == kNumOutputsWithState ||
2044                 outputIndexes.size() == kNumOutputsMergedWithState) {
2045                 version = Version::ANDROID_R;
2046             }
2047             std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType);
2048             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2049                                                  outputIndexes, outExpectedTypes));
2050             return version;
2051         }
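             // LSTM: the base 23 inputs are 20 tensors, an INT32 activation, and two clipping
             // scalars (FLOAT32 or FLOAT16, matching the input tensor type). The 27-input form
             // appends 4 layer-normalization tensors and, like FLOAT16 input, raises the
             // requirement to Android Q.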
2052         case OperationType::LSTM: {
2053             NN_VALIDATE((inputIndexes.size() == 23 || inputIndexes.size() == 27) &&
2054                         outputIndexes.size() == 4)
2055                     << "Invalid number of input operands (" << inputIndexes.size()
2056                     << ", expected 23 or 27) or output operands (" << outputIndexes.size()
2057                     << ", expected 4) for operation " << opType;
2058             std::vector<OperandType> inExpectedTypes;
2059             std::vector<OperandType> outExpectedTypes;
2060             auto inputType = operands[inputIndexes[0]].type;
2061             NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
2062                         inputType == OperandType::TENSOR_FLOAT16)
2063                     << "Unsupported input tensor type for operation " << opType;
2064 
2065             Version version = Version::ANDROID_OC_MR1;
2066             inExpectedTypes = {inputType,         inputType, inputType, inputType, inputType,
2067                                inputType,         inputType, inputType, inputType, inputType,
2068                                inputType,         inputType, inputType, inputType, inputType,
2069                                inputType,         inputType, inputType, inputType, inputType,
2070                                OperandType::INT32};
2071             if (inputType == OperandType::TENSOR_FLOAT32) {
2072                 inExpectedTypes.push_back(OperandType::FLOAT32);
2073                 inExpectedTypes.push_back(OperandType::FLOAT32);
2074             } else {
2075                 version = Version::ANDROID_Q;
2076                 inExpectedTypes.push_back(OperandType::FLOAT16);
2077                 inExpectedTypes.push_back(OperandType::FLOAT16);
2078             }
2079 
2080             outExpectedTypes = {inputType, inputType, inputType, inputType};
2081             if (inputIndexes.size() == 23) {
2082                 version = combineVersions(version, Version::ANDROID_OC_MR1);
2083             } else {
2084                 version = combineVersions(version, Version::ANDROID_Q);
2085                 for (int i = 0; i < 4; ++i) {
2086                     inExpectedTypes.push_back(inputType);
2087                 }
2088             }
2089             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2090                                                  outputIndexes, outExpectedTypes));
2091             return version;
2092         }
2093         case OperationType::QUANTIZED_16BIT_LSTM: {
2094             NN_VALIDATE(inputIndexes.size() == 15 && outputIndexes.size() == 2)
2095                     << invalidInOutNumberMessage(15, 2);
2096             std::vector<OperandType> inExpectedTypes = {
2097                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
2098                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
2099                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
2100                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
2101                     OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32,
2102                     OperandType::TENSOR_INT32,        OperandType::TENSOR_INT32,
2103                     OperandType::TENSOR_INT32,        OperandType::TENSOR_QUANT16_SYMM,
2104                     OperandType::TENSOR_QUANT8_ASYMM};
2105             std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_QUANT16_SYMM,
2106                                                          OperandType::TENSOR_QUANT8_ASYMM};
2107             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2108                                                  outputIndexes, outExpectedTypes));
2109             return Version::ANDROID_Q;
2110         }
2111         case OperationType::RANDOM_MULTINOMIAL: {
2112             NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
2113                     << invalidInOutNumberMessage(3, 1);
2114             OperandType inputType = operands[inputIndexes[0]].type;
2115             std::vector<OperandType> inExpectedTypes;
2116             if (inputType == OperandType::TENSOR_FLOAT32 ||
2117                 inputType == OperandType::TENSOR_FLOAT16) {
2118                 inExpectedTypes = {inputType, OperandType::INT32, OperandType::TENSOR_INT32};
2119             } else {
2120                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2121             }
2122             std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
2123             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2124                                                  outputIndexes, outExpectedTypes));
2125             return Version::ANDROID_Q;
2126         }
2127         case OperationType::RNN: {
2128             NN_VALIDATE(inputIndexes.size() == 6 && outputIndexes.size() == 2)
2129                     << invalidInOutNumberMessage(6, 2);
2130             OperandType inputType = operands[inputIndexes[0]].type;
2131             Version version;
2132             std::vector<OperandType> inExpectedTypes;
2133             std::vector<OperandType> outExpectedTypes;
2134             if (inputType == OperandType::TENSOR_FLOAT32) {
2135                 version = Version::ANDROID_OC_MR1;
2136                 inExpectedTypes = {
2137                         OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
2138                         OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
2139                         OperandType::TENSOR_FLOAT32, OperandType::INT32,
2140                 };
2141                 outExpectedTypes = {
2142                         OperandType::TENSOR_FLOAT32,
2143                         OperandType::TENSOR_FLOAT32,
2144                 };
2145             } else if (inputType == OperandType::TENSOR_FLOAT16) {
2146                 version = Version::ANDROID_Q;
2147                 inExpectedTypes = {
2148                         OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
2149                         OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
2150                         OperandType::TENSOR_FLOAT16, OperandType::INT32,
2151                 };
2152                 outExpectedTypes = {
2153                         OperandType::TENSOR_FLOAT16,
2154                         OperandType::TENSOR_FLOAT16,
2155                 };
2156             } else {
2157                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2158             }
2159             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2160                                                  outputIndexes, outExpectedTypes));
2161             return version;
2162         }
2163         case OperationType::SVDF: {
2164             NN_VALIDATE(inputIndexes.size() == 7 && outputIndexes.size() == 2)
2165                     << invalidInOutNumberMessage(7, 2);
2166             Version version;
2167             OperandType inputType = operands[inputIndexes[0]].type;
2168             if (inputType == OperandType::TENSOR_FLOAT32) {
2169                 version = Version::ANDROID_OC_MR1;
2170             } else if (inputType == OperandType::TENSOR_FLOAT16) {
2171                 version = Version::ANDROID_Q;
2172             } else {
2173                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2174             }
2175             std::vector<OperandType> inExpectedTypes = {
2176                     inputType, inputType,          inputType,          inputType,
2177                     inputType, OperandType::INT32, OperandType::INT32,
2178             };
2179             std::vector<OperandType> outExpectedTypes = {inputType, inputType};
2180             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2181                                                  outputIndexes, outExpectedTypes));
2182             return version;
2183         }
2184         case OperationType::BATCH_TO_SPACE_ND: {
2185             NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
2186                         outputIndexes.size() == 1)
2187                     << "Invalid number of input operands (" << inputIndexes.size()
2188                     << ", expected 3 or 2) or output operands (" << outputIndexes.size()
2189                     << ", expected 1) for operation " << opType;
2190             auto inputType = operands[inputIndexes[0]].type;
2191             Version version = Version::ANDROID_OC_MR1;
2192             std::vector<OperandType> inExpectedTypes;
2193             std::vector<OperandType> outExpectedTypes;
2194             if (inputType == OperandType::TENSOR_FLOAT32) {
2195                 inExpectedTypes = {
2196                         OperandType::TENSOR_FLOAT32,
2197                         OperandType::TENSOR_INT32,
2198                 };
2199                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2200             } else if (inputType == OperandType::TENSOR_FLOAT16) {
2201                 version = Version::ANDROID_Q;
2202                 inExpectedTypes = {
2203                         OperandType::TENSOR_FLOAT16,
2204                         OperandType::TENSOR_INT32,
2205                 };
2206                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2207             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
2208                 inExpectedTypes = {
2209                         OperandType::TENSOR_QUANT8_ASYMM,
2210                         OperandType::TENSOR_INT32,
2211                 };
2212                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
2213             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2214                 version = Version::ANDROID_R;
2215                 inExpectedTypes = {
2216                         OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
2217                         OperandType::TENSOR_INT32,
2218                 };
2219                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
2220             } else {
2221                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2222             }
2223             if (inputIndexes.size() == 3) {
2224                 inExpectedTypes.push_back(OperandType::BOOL);
2225                 version = combineVersions(version, Version::ANDROID_Q);
2226             } else {
2227                 version = combineVersions(version, Version::ANDROID_P);
2228             }
2229             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2230                                                  outputIndexes, outExpectedTypes));
2231             return version;
2232         }
2233         case OperationType::SPACE_TO_BATCH_ND: {
2234             NN_VALIDATE((inputIndexes.size() == 4 || inputIndexes.size() == 3) &&
2235                         outputIndexes.size() == 1)
2236                     << "Invalid number of input operands (" << inputIndexes.size()
2237                     << ", expected 4 or 3) or output operands (" << outputIndexes.size()
2238                     << ", expected 1) for operation " << opType;
2239             auto inputType = operands[inputIndexes[0]].type;
2240             Version version = Version::ANDROID_OC_MR1;
2241             std::vector<OperandType> inExpectedTypes;
2242             std::vector<OperandType> outExpectedTypes;
2243             if (inputType == OperandType::TENSOR_FLOAT32) {
2244                 inExpectedTypes = {
2245                         OperandType::TENSOR_FLOAT32,
2246                         OperandType::TENSOR_INT32,
2247                         OperandType::TENSOR_INT32,
2248                 };
2249                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2250             } else if (inputType == OperandType::TENSOR_FLOAT16) {
2251                 version = Version::ANDROID_Q;
2252                 inExpectedTypes = {
2253                         OperandType::TENSOR_FLOAT16,
2254                         OperandType::TENSOR_INT32,
2255                         OperandType::TENSOR_INT32,
2256                 };
2257                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2258             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
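                     // A TENSOR_QUANT8_ASYMM input with a non-zero zero point requires Android Q;
                     // a zero point of 0 keeps the lower requirement computed below.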
2259                 if (operands[inputIndexes[0]].zeroPoint != 0) {
2260                     version = Version::ANDROID_Q;
2261                 }
2262                 inExpectedTypes = {
2263                         OperandType::TENSOR_QUANT8_ASYMM,
2264                         OperandType::TENSOR_INT32,
2265                         OperandType::TENSOR_INT32,
2266                 };
2267                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
2268             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2269                 version = Version::ANDROID_R;
2270                 inExpectedTypes = {
2271                         OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
2272                         OperandType::TENSOR_INT32,
2273                         OperandType::TENSOR_INT32,
2274                 };
2275                 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
2276             } else {
2277                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2278             }
2279             if (inputIndexes.size() == 4) {
2280                 inExpectedTypes.push_back(OperandType::BOOL);
2281                 version = combineVersions(version, Version::ANDROID_Q);
2282             } else {
2283                 version = combineVersions(version, Version::ANDROID_P);
2284             }
2285             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2286                                                  outputIndexes, outExpectedTypes));
2287             return version;
2288         }
2289         case OperationType::PAD: {
2290             NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2291                     << invalidInOutNumberMessage(2, 1);
2292             auto inputType = operands[inputIndexes[0]].type;
2293             Version version;
2294             std::vector<OperandType> inExpectedTypes;
2295             std::vector<OperandType> outExpectedTypes;
2296             if (inputType == OperandType::TENSOR_FLOAT32) {
2297                 version = Version::ANDROID_P;
2298                 inExpectedTypes = {
2299                         OperandType::TENSOR_FLOAT32,
2300                         OperandType::TENSOR_INT32,
2301                 };
2302                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2303             } else if (inputType == OperandType::TENSOR_FLOAT16) {
2304                 version = Version::ANDROID_Q;
2305                 inExpectedTypes = {
2306                         OperandType::TENSOR_FLOAT16,
2307                         OperandType::TENSOR_INT32,
2308                 };
2309                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2310             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2311                        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2312                 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2313                     version = Version::ANDROID_R;
2314                 } else {
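                         // TENSOR_QUANT8_ASYMM is accepted from Android P when its zero point is
                         // 0; a non-zero zero point requires Android Q.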
2315                     if (operands[inputIndexes[0]].zeroPoint == 0) {
2316                         version = Version::ANDROID_P;
2317                     } else {
2318                         version = Version::ANDROID_Q;
2319                     }
2320                 }
2321                 inExpectedTypes = {
2322                         inputType,
2323                         OperandType::TENSOR_INT32,
2324                 };
2325                 outExpectedTypes = {inputType};
2326             } else {
2327                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2328             }
2329             const auto inputRank = operands[inputIndexes[0]].dimensions.size();
2330             NN_VALIDATE_LE(inputRank, 4u)
2331                     << "Unsupported input tensor rank for operation " << opType;
2332             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2333                                                  outputIndexes, outExpectedTypes));
2334             return version;
2335         }
2336         case OperationType::PAD_V2: {
2337             NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
2338                     << invalidInOutNumberMessage(3, 1);
2339             auto inputType = operands[inputIndexes[0]].type;
2340             Version version;
2341             std::vector<OperandType> inExpectedTypes;
2342             std::vector<OperandType> outExpectedTypes;
2343             if (inputType == OperandType::TENSOR_FLOAT32) {
2344                 version = Version::ANDROID_Q;
2345                 inExpectedTypes = {
2346                         OperandType::TENSOR_FLOAT32,
2347                         OperandType::TENSOR_INT32,
2348                         OperandType::FLOAT32,
2349                 };
2350                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2351             } else if (inputType == OperandType::TENSOR_FLOAT16) {
2352                 version = Version::ANDROID_Q;
2353                 inExpectedTypes = {
2354                         OperandType::TENSOR_FLOAT16,
2355                         OperandType::TENSOR_INT32,
2356                         OperandType::FLOAT16,
2357                 };
2358                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2359             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2360                        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2361                 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2362                     version = Version::ANDROID_R;
2363                 } else {
2364                     version = Version::ANDROID_Q;
2365                 }
2366                 inExpectedTypes = {
2367                         inputType,
2368                         OperandType::TENSOR_INT32,
2369                         OperandType::INT32,
2370                 };  // TODO(b/116699425): Make it UINT8.
2371                 outExpectedTypes = {inputType};
2372             } else {
2373                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2374             }
2375             const auto inputRank = operands[inputIndexes[0]].dimensions.size();
2376             NN_VALIDATE_LE(inputRank, 4u)
2377                     << "Unsupported input tensor rank for operation " << opType;
2378             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2379                                                  outputIndexes, outExpectedTypes));
2380             return version;
2381         }
2382         case OperationType::CAST: {
2383             NN_VALIDATE(inputIndexes.size() == 1 && outputIndexes.size() == 1)
2384                     << invalidInOutNumberMessage(1, 1);
2385             auto inputOperand = operands[inputIndexes[0]];
2386             auto outputOperand = operands[outputIndexes[0]];
2387             auto inputType = inputOperand.type;
2388             auto outputType = outputOperand.type;
2389             Version version;
2390             std::vector<OperandType> inExpectedTypes;
2391             std::vector<OperandType> outExpectedTypes;
2392             if ((inputType == OperandType::TENSOR_FLOAT16 ||
2393                  inputType == OperandType::TENSOR_FLOAT32 ||
2394                  inputType == OperandType::TENSOR_INT32 ||
2395                  inputType == OperandType::TENSOR_QUANT8_ASYMM) &&
2396                 (outputType == OperandType::TENSOR_FLOAT16 ||
2397                  outputType == OperandType::TENSOR_FLOAT32 ||
2398                  outputType == OperandType::TENSOR_INT32 ||
2399                  outputType == OperandType::TENSOR_QUANT8_ASYMM)) {
2400                 version = Version::ANDROID_Q;
2401                 inExpectedTypes = {inputType};
2402                 outExpectedTypes = {outputType};
2403             } else if (inputType == OperandType::TENSOR_BOOL8 ||
2404                        inputType == OperandType::TENSOR_QUANT16_ASYMM ||
2405                        inputType == OperandType::TENSOR_QUANT16_SYMM ||
2406                        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
2407                        inputType == OperandType::TENSOR_QUANT8_SYMM) {
2408                 version = Version::ANDROID_R;
2409                 inExpectedTypes = {inputType};
2410                 outExpectedTypes = {inputType};  // Only identity CAST is supported.
2411             } else {
2412                 NN_VALIDATE_FAIL() << "Unsupported data type for operation " << opType;
2413             }
2414             // Validate that output shape is equal to input shape if dimensions
2415             // are already known.
2416             auto getNumberOfElements = [](const std::vector<uint32_t>& dims) {
2417                 if (dims.empty()) {
2418                     return 0;
2419                 }
2420                 return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
2421             };
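                 // A dimension of 0 denotes "unspecified", so a zero element count means the
                 // output shape is not yet fully known and the equality check is skipped.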
2422             NN_VALIDATE(inputOperand.dimensions.empty() || outputOperand.dimensions.empty() ||
2423                         getNumberOfElements(outputOperand.dimensions) == 0 ||
2424                         inputOperand.dimensions == outputOperand.dimensions);
2425             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2426                                                  outputIndexes, outExpectedTypes));
2427             return version;
2428         }
2429         case OperationType::MEAN: {
2430             NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
2431                     << invalidInOutNumberMessage(3, 1);
2432             const auto inputRank = operands[inputIndexes[0]].dimensions.size();
2433             NN_VALIDATE_LE(inputRank, 4u)
2434                     << "Unsupported input tensor rank for operation " << opType;
2435             auto inputType = operands[inputIndexes[0]].type;
2436             Version version;
2437             if (inputType == OperandType::TENSOR_FLOAT32 ||
2438                 inputType == OperandType::TENSOR_QUANT8_ASYMM) {
2439                 version = Version::ANDROID_P;
2440             } else if (inputType == OperandType::TENSOR_FLOAT16) {
2441                 version = Version::ANDROID_Q;
2442             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2443                 version = Version::ANDROID_R;
2444             } else {
2445                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2446             }
2447             std::vector<OperandType> inExpectedTypes = {inputType, OperandType::TENSOR_INT32,
2448                                                         OperandType::INT32};
2449             std::vector<OperandType> outExpectedTypes = {inputType};
2450             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2451                                                  outputIndexes, outExpectedTypes));
2452             return version;
2453         }
2454         case OperationType::ARGMAX:
2455         case OperationType::ARGMIN: {
2456             NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2457                     << invalidInOutNumberMessage(2, 1);
2458             auto inputType = operands[inputIndexes[0]].type;
2459             std::vector<OperandType> inExpectedTypes;
2460             std::vector<OperandType> outExpectedTypes;
2461             if (inputType == OperandType::TENSOR_FLOAT16 ||
2462                 inputType == OperandType::TENSOR_FLOAT32 ||
2463                 inputType == OperandType::TENSOR_INT32 ||
2464                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2465                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2466                 inExpectedTypes = {inputType, OperandType::INT32};
2467                 outExpectedTypes = {OperandType::TENSOR_INT32};
2468             } else {
2469                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2470             }
2471             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2472                                                  outputIndexes, outExpectedTypes));
2473             return Version::ANDROID_Q;
2474         }
2475         case OperationType::EXPAND_DIMS: {
2476             NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2477                     << invalidInOutNumberMessage(2, 1);
2478             auto inputType = operands[inputIndexes[0]].type;
2479             std::vector<OperandType> inExpectedTypes;
2480             std::vector<OperandType> outExpectedTypes;
2481             if (inputType == OperandType::TENSOR_FLOAT16 ||
2482                 inputType == OperandType::TENSOR_FLOAT32 ||
2483                 inputType == OperandType::TENSOR_INT32 ||
2484                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2485                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2486                 inExpectedTypes = {inputType, OperandType::INT32};
2487                 outExpectedTypes = {inputType};
2488             } else {
2489                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2490             }
2491             Version version;
2492             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2493                 version = Version::ANDROID_R;
2494             } else {
2495                 version = Version::ANDROID_Q;
2496             }
2497             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2498                                                  outputIndexes, outExpectedTypes));
2499             return version;
2500         }
2501         case OperationType::SPLIT: {
2502             NN_VALIDATE_EQ(inputIndexes.size(), 3u)
2503                     << "Invalid number of input operands (" << inputIndexes.size()
2504                     << ", expected 3) for operation " << opType;
2505             auto inputType = operands[inputIndexes[0]].type;
2506             NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
2507                         inputType == OperandType::TENSOR_FLOAT32 ||
2508                         inputType == OperandType::TENSOR_INT32 ||
2509                         inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2510                         inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
2511                     << "Unsupported input tensor type for operation " << opType;
2512             Version version;
2513             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2514                 version = Version::ANDROID_R;
2515             } else {
2516                 version = Version::ANDROID_Q;
2517             }
2518             std::vector<OperandType> inExpectedTypes = {inputType, OperandType::INT32,
2519                                                         OperandType::INT32};
2520             std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType);
2521             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2522                                                  outputIndexes, outExpectedTypes));
2523             return version;
2524         }
2525         case OperationType::MAXIMUM:
2526         case OperationType::MINIMUM: {
2527             NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2528                     << invalidInOutNumberMessage(2, 1);
2529             std::vector<OperandType> inExpectedTypes;
2530             std::vector<OperandType> outExpectedTypes;
2531             OperandType inputType = operands[inputIndexes[0]].type;
2532             if (inputType == OperandType::TENSOR_FLOAT16 ||
2533                 inputType == OperandType::TENSOR_FLOAT32 ||
2534                 inputType == OperandType::TENSOR_INT32 ||
2535                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2536                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2537                 inExpectedTypes = {inputType, inputType};
2538                 outExpectedTypes = {inputType};
2539             } else {
2540                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2541             }
2542             Version version;
2543             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2544                 version = Version::ANDROID_R;
2545             } else {
2546                 version = Version::ANDROID_Q;
2547             }
2548             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2549                                                  outputIndexes, outExpectedTypes));
2550             return version;
2551         }
2552         case OperationType::GROUPED_CONV_2D: {
2553             NN_VALIDATE((inputIndexes.size() == 12 || inputIndexes.size() == 9) &&
2554                         outputIndexes.size() == 1)
2555                     << "Invalid number of input operands (" << inputIndexes.size()
2556                     << ", expected 12 or 9) or output operands (" << outputIndexes.size()
2557                     << ", expected 1) for operation " << opType;
2558             auto inputType = operands[inputIndexes[0]].type;
2559             auto filterType = operands[inputIndexes[1]].type;
2560             std::vector<OperandType> inExpectedTypes;
2561             std::vector<OperandType> outExpectedTypes;
2562             if (inputType == OperandType::TENSOR_FLOAT32) {
2563                 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
2564                                    OperandType::TENSOR_FLOAT32, OperandType::INT32,
2565                                    OperandType::INT32,          OperandType::INT32,
2566                                    OperandType::INT32,          OperandType::INT32};
2567                 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2568             } else if (inputType == OperandType::TENSOR_FLOAT16) {
2569                 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
2570                                    OperandType::TENSOR_FLOAT16, OperandType::INT32,
2571                                    OperandType::INT32,          OperandType::INT32,
2572                                    OperandType::INT32,          OperandType::INT32};
2573                 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2574             } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2575                        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2576                 NN_VALIDATE(filterType == inputType ||
2577                             filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)
2578                         << "Unsupported filter tensor type for operation " << opType;
2579 
2580                 NN_VALIDATE(filterType != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
2581                             std::get<Operand::SymmPerChannelQuantParams>(
2582                                     operands[inputIndexes[1]].extraParams)
2583                                             .channelDim == 0)
2584                         << "Unsupported filter tensor channel dimension for operation " << opType;
2585 
2586                 inExpectedTypes = {
2587                         inputType,          filterType,         OperandType::TENSOR_INT32,
2588                         OperandType::INT32, OperandType::INT32, OperandType::INT32,
2589                         OperandType::INT32, OperandType::INT32};
2590                 outExpectedTypes = {inputType};
2591             } else {
2592                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2593             }
2594 
2595             if (inputIndexes.size() == 12) {
2596                 std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
2597                 inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
2598                                        explicitScalarTypes.end());
2599             }
2600             inExpectedTypes.push_back(OperandType::BOOL);
2601             Version version;
2602             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2603                 version = Version::ANDROID_R;
2604             } else {
2605                 version = Version::ANDROID_Q;
2606             }
2607             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2608                                                  outputIndexes, outExpectedTypes));
2609             return version;
2610         }
2611         case OperationType::TILE: {
2612             NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2613                     << invalidInOutNumberMessage(2, 1);
2614             auto inputType = operands[inputIndexes[0]].type;
2615             std::vector<OperandType> inExpectedTypes;
2616             std::vector<OperandType> outExpectedTypes;
2617             if (inputType == OperandType::TENSOR_FLOAT16 ||
2618                 inputType == OperandType::TENSOR_FLOAT32 ||
2619                 inputType == OperandType::TENSOR_INT32 ||
2620                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2621                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2622                 inExpectedTypes = {inputType, OperandType::TENSOR_INT32};
2623                 outExpectedTypes = {inputType};
2624             } else {
2625                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2626             }
2627             Version version;
2628             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2629                 version = Version::ANDROID_R;
2630             } else {
2631                 version = Version::ANDROID_Q;
2632             }
2633             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2634                                                  outputIndexes, outExpectedTypes));
2635             return version;
2636         }
2637         case OperationType::POW: {
2638             NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2639                     << invalidInOutNumberMessage(2, 1);
2640             auto inputType = operands[inputIndexes[0]].type;
2641             std::vector<OperandType> inExpectedTypes;
2642             std::vector<OperandType> outExpectedTypes;
2643             if (inputType == OperandType::TENSOR_FLOAT16 ||
2644                 inputType == OperandType::TENSOR_FLOAT32) {
2645                 inExpectedTypes = {inputType, inputType};
2646                 outExpectedTypes = {inputType};
2647             } else {
2648                 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2649             }
2650             Version version;
2651             if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2652                 version = Version::ANDROID_R;
2653             } else {
2654                 version = Version::ANDROID_Q;
2655             }
2656             NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2657                                                  outputIndexes, outExpectedTypes));
2658             return version;
2659         }
2660         case OperationType::IF: {
2661             return validateIfOperation(inputIndexes, outputIndexes, operands, subgraphs);
2662         }
2663         case OperationType::WHILE: {
2664             return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs);
2665         }
2666         default: {
2667             const OperationRegistration* operationRegistration =
2668                     BuiltinOperationResolver::get()->findOperation(
2669                             static_cast<OperationType>(opType));
2670             // TODO: return ErrorStatus::UNEXPECTED_NULL
2671             NN_VALIDATE(operationRegistration != nullptr) << opType << " not registered";
2672             // TODO: return ErrorStatus::UNEXPECTED_NULL
2673             NN_VALIDATE(operationRegistration->validate != nullptr)
2674                     << "Incomplete operation registration: " << opType;
2675 
2676             OperationValidationContext context(operationRegistration->name, inputIndexes,
2677                                                outputIndexes, operands);
2678             return operationRegistration->validate(&context);
2679         }
2680     }
2681 }
2682 
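// Combines the version required by the operation itself with the version required by each of
// its input and output operands.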
2683 Result<Version> validateOperationIncludingOperandVersions(
2684         const Operation& operation, const std::vector<Operand>& operands,
2685         const std::vector<Version>& operandVersions,
2686         const std::vector<Model::Subgraph>& subgraphs) {
2687     auto version = NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs));
2688     for (uint32_t index : operation.inputs) {
2689         version = combineVersions(version, operandVersions[index]);
2690     }
2691     for (uint32_t index : operation.outputs) {
2692         version = combineVersions(version, operandVersions[index]);
2693     }
2694     return version;
2695 }
2696 
2697 }  // anonymous namespace
2698 
2699 // Below this point are all the functions that are declared in Validation.h. The convention of this
2700 // file is to keep the function bodies of the functions declared in Validation.h minimal, meaning
2701 // that most functions below simply redirect to one of the functions defined above in the anonymous
2702 // namespace. If there is a function name clash between one of the functions below and one of the
2703 // functions above, the function in the anonymous namespace is appended with "Impl".
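// Illustrative usage of the validate() overloads below -- a minimal sketch only, assuming
// Result<Version> behaves like android::base::expected. "model" stands for a hypothetical,
// fully populated Model instance; it is not defined in this file.
//
//   const Result<Version> result = validate(model);
//   if (!result.has_value()) {
//       LOG(ERROR) << "Model failed validation: " << result.error();
//   } else {
//       // result.value() is the minimum version (feature level) the model requires.
//   }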
2704 
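// Versions form a total order, so the combined requirement of two versions is simply the
// larger of the two.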
2705 Version combineVersions(Version lhs, Version rhs) {
2706     return std::max<Version>(lhs, rhs);
2707 }
2708 
2709 Result<Version> validate(const DeviceStatus& deviceStatus) {
2710     return validateDeviceStatus(deviceStatus);
2711 }
2712 
2713 Result<Version> validate(const ExecutionPreference& executionPreference) {
2714     return validateExecutionPreference(executionPreference);
2715 }
2716 
2717 Result<Version> validate(const DeviceType& deviceType) {
2718     return validateDeviceType(deviceType);
2719 }
2720 
2721 Result<Version> validate(const MeasureTiming& measureTiming) {
2722     return validateMeasureTiming(measureTiming);
2723 }
2724 
2725 Result<Version> validate(const OperandType& operandType) {
2726     return validateOperandType(operandType);
2727 }
2728 
2729 Result<Version> validate(const Priority& priority) {
2730     return validatePriority(priority);
2731 }
2732 
2733 Result<Version> validate(const ErrorStatus& errorStatus) {
2734     return validateErrorStatus(errorStatus);
2735 }
2736 
2737 Result<Version> validate(const FusedActivationFunc& activation) {
2738     return validateFusedActivationFunc(activation);
2739 }
2740 
2741 Result<Version> validate(const OutputShape& outputShape) {
2742     return validateOutputShape(outputShape);
2743 }
2744 
2745 Result<Version> validate(const Timing& timing) {
2746     return validateTiming(timing);
2747 }
2748 
2749 Result<Version> validate(const Capabilities& capabilities) {
2750     return validateCapabilities(capabilities);
2751 }
2752 
2753 Result<Version> validate(const Extension& extension) {
2754     return validateExtension(extension);
2755 }
2756 
2757 Result<Version> validate(const SharedHandle& handle) {
2758     return validateSharedHandle(handle);
2759 }
2760 
2761 Result<Version> validate(const SharedMemory& memory) {
2762     return validateSharedMemory(memory);
2763 }
2764 
2765 Result<Version> validate(const Model& model) {
2766     return validateModel(model);
2767 }
2768 
2769 Result<Version> validate(const BufferDesc& bufferDesc) {
2770     return validateBufferDesc(bufferDesc);
2771 }
2772 
2773 Result<Version> validate(const BufferRole& bufferRole) {
2774     return validateBufferRole(bufferRole);
2775 }
2776 
2777 Result<Version> validate(const Request& request) {
2778     return validateRequest(request);
2779 }
2780 
2781 Result<Version> validate(const OptionalTimePoint& optionalTimePoint) {
2782     return validateOptionalTimePoint(optionalTimePoint);
2783 }
2784 
2785 Result<Version> validate(const OptionalDuration& optionalTimeoutDuration) {
2786     return validateOptionalTimeoutDuration(optionalTimeoutDuration);
2787 }
2788 
2789 Result<Version> validate(const CacheToken& cacheToken) {
2790     return validateCacheToken(cacheToken);
2791 }
2792 
2793 Result<Version> validate(const SyncFence& syncFence) {
2794     return validateSyncFence(syncFence);
2795 }
2796 
2797 Result<Version> validate(const std::vector<OutputShape>& outputShapes) {
2798     return validateVector(outputShapes, validateOutputShape);
2799 }
2800 
2801 Result<Version> validate(const std::vector<Extension>& extensions) {
2802     return validateExtensions(extensions);
2803 }
2804 
2805 Result<Version> validate(const std::vector<SharedHandle>& handles) {
2806     return validateVector(handles, validateSharedHandle);
2807 }
2808 
2809 Result<Version> validate(const std::vector<BufferRole>& bufferRoles) {
2810     return validateVector(bufferRoles, validateBufferRole);
2811 }
2812 
2813 Result<Version> validate(const std::vector<SyncFence>& syncFences) {
2814     return validateVector(syncFences, validateSyncFence);
2815 }
2816 
2817 Result<Version> validateRequestForModel(const Request& request, const Model& model,
2818                                         bool allowUnspecifiedOutput) {
2819     return validateRequestForModelImpl(request, model, allowUnspecifiedOutput);
2820 }
2821 
2822 Result<Version> validateMemoryDesc(
2823         const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
2824         const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
2825         const std::function<const Model*(const SharedPreparedModel&)>& getModel,
2826         std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
2827     return validateMemoryDescImpl(desc, preparedModels, inputRoles, outputRoles, getModel,
2828                                   preparedModelRoles, combinedOperand);
2829 }
2830 
2831 Result<void> validateOperandSymmPerChannelQuantParams(
2832         const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
2833         const char* tag) {
2834     return validateOperandSymmPerChannelQuantParamsImpl(operand, channelQuant, tag);
2835 }
2836 
2837 Result<void> validateOperandType(const Operand& type,
2838                                  const Extension::OperandTypeInformation* extensionOperandTypeInfo,
2839                                  const char* tag, bool allowPartial) {
2840     return validateOperandTypeImpl(type, extensionOperandTypeInfo, tag, allowPartial);
2841 }
2842 
2843 Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount,
2844                                  const char* tag) {
2845     return validateOperandListImpl(list, operandCount, tag);
2846 }
2847 
2848 Result<void> validateOperationButNotOperands(const Operation& operation,
2849                                              const std::vector<Operand>& operands,
2850                                              const std::vector<Model::Subgraph>& subgraphs) {
2851     NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs));
2852     return {};
2853 }
2854 
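// Caches the version computed for each referenced subgraph so that a subgraph reachable from
// multiple operands does not need to be re-validated.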
2855 struct SubgraphVersionCache {
2856     std::vector<std::optional<Version>> cache;
2857 };
2858 
2859 std::unique_ptr<SubgraphVersionCache, void (*)(SubgraphVersionCache*)> createSubgraphVersionCache(
2860         size_t referencedSubgraphCount) {
2861     auto subgraphVersionCache = std::make_unique<SubgraphVersionCache>();
2862     subgraphVersionCache->cache.resize(referencedSubgraphCount);
2863     constexpr auto deleter = [](SubgraphVersionCache* ptr) { delete ptr; };
2864     return {subgraphVersionCache.release(), deleter};
2865 }
2866 
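// Validates every operand the operation reads or writes (and, through the cache, any subgraph
// those operands reference) before validating the operation itself, then combines all of the
// resulting version requirements.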
2867 Result<Version> validateOperationAndAnythingItDependsOn(
2868         const Operation& operation, const std::vector<Operand>& operands, size_t operandValuesSize,
2869         const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs,
2870         SubgraphVersionCache* subgraphVersionCache) {
2871     CHECK(subgraphVersionCache != nullptr);
2872     std::vector<Version> operandVersions(operands.size(), Version::ANDROID_OC_MR1);
2873     for (uint32_t index : operation.inputs) {
2874         NN_VALIDATE_LT(index, operands.size());
2875         const Operand& operand = operands[index];
2876         operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn(
2877                 operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache));
2878     }
2879     for (uint32_t index : operation.outputs) {
2880         NN_VALIDATE_LT(index, operands.size());
2881         const Operand& operand = operands[index];
2882         operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn(
2883                 operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache));
2884     }
2885     return validateOperationIncludingOperandVersions(operation, operands, operandVersions,
2886                                                      subgraphs);
2887 }
2888 
2889 Result<Version> validateOperandAndAnythingItDependsOn(const Operand& operand,
2890                                                       size_t operandValuesSize,
2891                                                       const std::vector<size_t>& poolSizes,
2892                                                       const std::vector<Model::Subgraph>& subgraphs,
2893                                                       SubgraphVersionCache* subgraphVersionCache) {
2894     CHECK(subgraphVersionCache != nullptr);
2895     return validateOperand(operand, operandValuesSize, poolSizes, subgraphs,
2896                            &subgraphVersionCache->cache);
2897 }
2898 
2899 }  // namespace android::nn
2900