1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "Validation.h"
18
19 #include <android-base/logging.h>
20 #include <android-base/mapped_file.h>
21
22 #include <algorithm>
23 #include <cctype>
24 #include <functional>
25 #include <limits>
26 #include <memory>
27 #include <numeric>
28 #include <set>
29 #include <sstream>
30 #include <string>
31 #include <string_view>
32 #include <tuple>
33 #include <utility>
34 #include <variant>
35 #include <vector>
36
37 #include "ControlFlow.h"
38 #include "OperandTypes.h"
39 #include "OperationResolver.h"
40 #include "OperationTypes.h"
41 #include "Result.h"
42 #include "SharedMemory.h"
43 #include "TypeUtils.h"
44 #include "Types.h"
45
46 // The NN_VALIDATE family of macros defined below is similar to the CHECK family defined in
47 // system/libbase/include/android-base/logging.h
48 //
49 // The difference is that NN_VALIDATE macros use LOG(ERROR) instead of LOG(FATAL)
50 // and return false instead of aborting.
51
// Fails validation unconditionally. The expansion is a `return` of an NN_ERROR() builder that is
// pre-populated with the file and line of the failing check; extra context can be streamed after
// the macro. For example:
//
//     NN_VALIDATE_FAIL() << "Something went wrong";
//
// Because the expansion is a `return` statement, the containing function's return type must be
// constructible from the NN_ERROR() builder (the validators below return Result<Version>).
#define NN_VALIDATE_FAIL() \
    return NN_ERROR() << "NN_VALIDATE failed (" << __FILE__ << ":" << __LINE__ << "): "
59
// Fails validation through NN_VALIDATE_FAIL when `condition` is false. The stringized condition
// is placed at the start of the message and extra context can be streamed afterwards; the
// streamed operands are only evaluated on failure. For example:
//
//     NN_VALIDATE(false) << "Something went wrong";
//
// The `while` wrapper keeps the macro usable as a single statement with a trailing `<<` chain;
// its body is a `return`, so it runs at most once. The containing function has the same
// return-type requirement as for NN_VALIDATE_FAIL.
#define NN_VALIDATE(condition) \
    while (UNLIKELY(!(condition))) NN_VALIDATE_FAIL() << #condition << " "
68
// Shared implementation of the NN_VALIDATE_xx(x, y) comparison macros.
//
// LHS and RHS are each evaluated exactly once, in the `for` initializer, and then compared with
// OP. On failure the message contains the stringized comparison followed by both evaluated
// values. The loop has no increment because its body is a `return`.
#define NN_VALIDATE_OP(LHS, RHS, OP)                                                        \
    for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS);                      \
         UNLIKELY(!(_values.lhs.v OP _values.rhs.v));                                       \
         /* empty */)                                                                       \
    NN_VALIDATE_FAIL()                                                                      \
            << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = "                   \
            << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
            << ", " << #RHS << " = "                                                        \
            << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
            << ") "
80
// Binary comparison checks built on NN_VALIDATE_OP: validation fails (the containing function
// returns an error) when the named relation between x and y does not hold. Extra context can be
// streamed afterwards. For example:
//
//     NN_VALIDATE_EQ(a, b) << "Something went wrong";
//
// Both operands must support the corresponding comparison operator as well as
// `operator<<(std::ostream&, ...)`, because their values are included in the failure message.
// The containing function has the same return-type requirement as for NN_VALIDATE_FAIL.
#define NN_VALIDATE_EQ(x, y) NN_VALIDATE_OP(x, y, ==)
#define NN_VALIDATE_NE(x, y) NN_VALIDATE_OP(x, y, !=)
#define NN_VALIDATE_LE(x, y) NN_VALIDATE_OP(x, y, <=)
#define NN_VALIDATE_LT(x, y) NN_VALIDATE_OP(x, y, <)
#define NN_VALIDATE_GE(x, y) NN_VALIDATE_OP(x, y, >=)
#define NN_VALIDATE_GT(x, y) NN_VALIDATE_OP(x, y, >)
95
96 namespace android::nn {
97 namespace {
98
99 constexpr auto kNullptrVariant = std::variant<const void*, void*>{};
100 constexpr auto kInvalidMemoryDomainToken = Request::MemoryDomainToken{};
101
102 template <typename Type, typename ValidationFunction>
validateVector(const std::vector<Type> & objects,const ValidationFunction & validationFunction)103 Result<Version> validateVector(const std::vector<Type>& objects,
104 const ValidationFunction& validationFunction) {
105 auto version = Version::ANDROID_OC_MR1;
106 for (const auto& object : objects) {
107 version = combineVersions(version, NN_TRY(validationFunction(object)));
108 }
109 return version;
110 }
111
// Returns true when `name` is an acceptable extension name: it consists solely of ASCII
// lowercase letters, ASCII digits, '.' and '_', and contains at least one '.'.
//
// The character classes are tested with explicit ASCII ranges rather than std::islower /
// std::isdigit: passing a plain `char` with a negative value (any byte >= 0x80 where char is
// signed) to the <cctype> functions is undefined behaviour, and std::islower is additionally
// locale-dependent.
bool isValidExtensionName(const std::string& name) {
    constexpr auto validSymbol = [](char symbol) {
        const bool isAsciiLower = 'a' <= symbol && symbol <= 'z';
        const bool isAsciiDigit = '0' <= symbol && symbol <= '9';
        return isAsciiLower || isAsciiDigit || symbol == '.' || symbol == '_';
    };
    const bool hasOnlyValidSymbols = std::all_of(name.begin(), name.end(), validSymbol);
    const bool hasAtLeastOnePeriod = std::find(name.begin(), name.end(), '.') != name.end();
    return hasOnlyValidSymbols && hasAtLeastOnePeriod;
}
120
validateDeviceStatus(const DeviceStatus & deviceStatus)121 Result<Version> validateDeviceStatus(const DeviceStatus& deviceStatus) {
122 switch (deviceStatus) {
123 case DeviceStatus::AVAILABLE:
124 case DeviceStatus::BUSY:
125 case DeviceStatus::OFFLINE:
126 case DeviceStatus::UNKNOWN:
127 return Version::ANDROID_OC_MR1;
128 }
129 NN_VALIDATE_FAIL() << "Invalid DeviceStatus " << deviceStatus;
130 }
131
validateExecutionPreference(const ExecutionPreference & executionPreference)132 Result<Version> validateExecutionPreference(const ExecutionPreference& executionPreference) {
133 switch (executionPreference) {
134 case ExecutionPreference::FAST_SINGLE_ANSWER:
135 // ExecutionPreference::FAST_SINGLE_ANSWER is the default value, so it is implicitly
136 // valid for all versions.
137 return Version::ANDROID_OC_MR1;
138 case ExecutionPreference::LOW_POWER:
139 case ExecutionPreference::SUSTAINED_SPEED:
140 return Version::ANDROID_P;
141 }
142 NN_VALIDATE_FAIL() << "Invalid ExecutionPreference " << executionPreference;
143 }
144
validateDeviceType(const DeviceType & deviceType)145 Result<Version> validateDeviceType(const DeviceType& deviceType) {
146 switch (deviceType) {
147 case DeviceType::UNKNOWN:
148 // DeviceType was introduced in the 1.2 NN HAL. DeviceType::UNKNOWN is returned when
149 // querying versions that are prior to the 1.2 NN HAL. DeviceType::UNKNOWN is not a
150 // valid code to return for a driver that implement at least a 1.2 NN HAL. If we need a
151 // range of versions, make ANDROID_Q (NN HAL 1.2) the exclusive upper bound for
152 // DeviceType::UNKNOWN.
153 return Version::ANDROID_OC_MR1;
154 case DeviceType::OTHER:
155 case DeviceType::CPU:
156 case DeviceType::GPU:
157 case DeviceType::ACCELERATOR:
158 return Version::ANDROID_Q;
159 }
160 NN_VALIDATE_FAIL() << "Invalid DeviceType " << deviceType;
161 }
162
validateMeasureTiming(const MeasureTiming & measureTiming)163 Result<Version> validateMeasureTiming(const MeasureTiming& measureTiming) {
164 switch (measureTiming) {
165 case MeasureTiming::NO:
166 // MeasureTiming::NO is the default value, so it is implicitly valid for all versions.
167 return Version::ANDROID_OC_MR1;
168 case MeasureTiming::YES:
169 return Version::ANDROID_Q;
170 }
171 NN_VALIDATE_FAIL() << "Invalid MeasureTiming " << measureTiming;
172 }
173
validateOperandType(const OperandType & operandType)174 Result<Version> validateOperandType(const OperandType& operandType) {
175 switch (operandType) {
176 case OperandType::FLOAT32:
177 case OperandType::INT32:
178 case OperandType::UINT32:
179 case OperandType::TENSOR_FLOAT32:
180 case OperandType::TENSOR_INT32:
181 case OperandType::TENSOR_QUANT8_ASYMM:
182 case OperandType::OEM:
183 case OperandType::TENSOR_OEM_BYTE:
184 return Version::ANDROID_OC_MR1;
185 case OperandType::BOOL:
186 case OperandType::TENSOR_QUANT16_SYMM:
187 case OperandType::TENSOR_FLOAT16:
188 case OperandType::TENSOR_BOOL8:
189 case OperandType::FLOAT16:
190 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
191 case OperandType::TENSOR_QUANT16_ASYMM:
192 case OperandType::TENSOR_QUANT8_SYMM:
193 return Version::ANDROID_Q;
194 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
195 case OperandType::SUBGRAPH:
196 return Version::ANDROID_R;
197 }
198 if (isExtension(operandType)) {
199 return Version::ANDROID_Q;
200 }
201 NN_VALIDATE_FAIL() << "Invalid OperandType " << operandType;
202 }
203
validateOperandLifeTime(const Operand & operand)204 Result<Version> validateOperandLifeTime(const Operand& operand) {
205 // Make sure SUBGRAPH operand type and lifetime always go together.
206 NN_VALIDATE_EQ((operand.type == OperandType::SUBGRAPH),
207 (operand.lifetime == Operand::LifeTime::SUBGRAPH))
208 << "Operand of type " << operand.type << " cannot have lifetime " << operand.lifetime;
209
210 switch (operand.lifetime) {
211 case Operand::LifeTime::TEMPORARY_VARIABLE:
212 case Operand::LifeTime::SUBGRAPH_INPUT:
213 case Operand::LifeTime::SUBGRAPH_OUTPUT:
214 case Operand::LifeTime::CONSTANT_COPY:
215 case Operand::LifeTime::CONSTANT_REFERENCE:
216 case Operand::LifeTime::NO_VALUE:
217 case Operand::LifeTime::POINTER:
218 return Version::ANDROID_OC_MR1;
219 case Operand::LifeTime::SUBGRAPH:
220 return Version::ANDROID_R;
221 }
222 NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
223 }
224
validatePriority(const Priority & priority)225 Result<Version> validatePriority(const Priority& priority) {
226 switch (priority) {
227 case Priority::MEDIUM:
228 // Priority::MEDIUM is the default value, so it is implicitly valid for all versions.
229 return Version::ANDROID_OC_MR1;
230 case Priority::LOW:
231 case Priority::HIGH:
232 return Version::ANDROID_R;
233 }
234 NN_VALIDATE_FAIL() << "Invalid Priority " << priority;
235 }
236
validateErrorStatus(const ErrorStatus & errorStatus)237 Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) {
238 // Note that MISSED_DEADLINE_*, RESOURCE_EXHAUSTED_*, and DEAD_OBJECT were introduced ih
239 // ANDROID_R, but these can be cast to ANDROID_OC_MR1 as GENERAL_FAILURE.
240 switch (errorStatus) {
241 case ErrorStatus::NONE:
242 case ErrorStatus::DEVICE_UNAVAILABLE:
243 case ErrorStatus::GENERAL_FAILURE:
244 case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
245 case ErrorStatus::INVALID_ARGUMENT:
246 case ErrorStatus::MISSED_DEADLINE_TRANSIENT:
247 case ErrorStatus::MISSED_DEADLINE_PERSISTENT:
248 case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT:
249 case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT:
250 case ErrorStatus::DEAD_OBJECT:
251 return Version::ANDROID_OC_MR1;
252 }
253 NN_VALIDATE_FAIL() << "Invalid ErrorStatus " << errorStatus;
254 }
255
validateFusedActivationFunc(const FusedActivationFunc & activation)256 Result<Version> validateFusedActivationFunc(const FusedActivationFunc& activation) {
257 switch (activation) {
258 case FusedActivationFunc::NONE:
259 case FusedActivationFunc::RELU:
260 case FusedActivationFunc::RELU1:
261 case FusedActivationFunc::RELU6:
262 return Version::ANDROID_OC_MR1;
263 }
264 NN_VALIDATE_FAIL() << "Invalid FusedActivationFunc " << activation;
265 }
266
validateOutputShape(const OutputShape &)267 Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) {
268 return Version::ANDROID_Q;
269 }
270
validateTiming(const Timing & timing)271 Result<Version> validateTiming(const Timing& timing) {
272 constexpr auto kNoTiming = Timing{};
273 if (timing == kNoTiming) {
274 // kNoTiming is the default value, so it is implicitly valid for all versions.
275 return Version::ANDROID_OC_MR1;
276 }
277 if (timing.timeInDriver.has_value() && timing.timeOnDevice.has_value()) {
278 // `lazyMessage` is a lazy function to produce the timing validation error message.
279 // Currently, the code is not able to inline the message in NN_VALIDATE due to a
280 // argument-dependent lookup issue with nn::detail::ErrorBuilder interacting with std types
281 // such as std::chrono::duration, so this function uses an indirection through
282 // std::ostringstream.
283 const auto lazyMessage = [&timing]() -> std::string {
284 std::ostringstream oss;
285 oss << "Timing::timeOnDevice (" << timing.timeOnDevice.value()
286 << ") must not exceed Timing::timeInDriver (" << timing.timeInDriver.value() << ")";
287 return oss.str();
288 };
289 NN_VALIDATE(timing.timeOnDevice.value() <= timing.timeInDriver.value()) << lazyMessage();
290 }
291 return Version::ANDROID_Q;
292 }
293
validateCapabilitiesPerformanceInfo(const Capabilities::PerformanceInfo & performanceInfo)294 Result<Version> validateCapabilitiesPerformanceInfo(
295 const Capabilities::PerformanceInfo& performanceInfo) {
296 NN_VALIDATE_GT(performanceInfo.execTime, 0.0f);
297 NN_VALIDATE_GT(performanceInfo.powerUsage, 0.0f);
298 return Version::ANDROID_OC_MR1;
299 }
300
validateCapabilitiesOperandPerformance(const Capabilities::OperandPerformance & operandPerformance)301 Result<Version> validateCapabilitiesOperandPerformance(
302 const Capabilities::OperandPerformance& operandPerformance) {
303 auto version = NN_TRY(validateOperandType(operandPerformance.type));
304 return combineVersions(version,
305 NN_TRY(validateCapabilitiesPerformanceInfo(operandPerformance.info)));
306 }
307
validateCapabilitiesOperandPerformanceTable(const Capabilities::OperandPerformanceTable & operandPerformances)308 Result<Version> validateCapabilitiesOperandPerformanceTable(
309 const Capabilities::OperandPerformanceTable& operandPerformances) {
310 // OperandPerformanceTable's order was validated when it was created, and it is castable to any
311 // version. If an OperandType does not exist in the lower version being converted to, that
312 // OperandPerformance will be dropped.
313 NN_TRY(validateVector(operandPerformances.asVector(), validateCapabilitiesOperandPerformance));
314 return Version::ANDROID_OC_MR1;
315 }
316
validateCapabilities(const Capabilities & capabilities)317 Result<Version> validateCapabilities(const Capabilities& capabilities) {
318 auto version =
319 NN_TRY(validateCapabilitiesOperandPerformanceTable(capabilities.operandPerformance));
320
321 version = combineVersions(version,
322 NN_TRY(validateCapabilitiesPerformanceInfo(
323 capabilities.relaxedFloat32toFloat16PerformanceScalar)));
324 version = combineVersions(version,
325 NN_TRY(validateCapabilitiesPerformanceInfo(
326 capabilities.relaxedFloat32toFloat16PerformanceTensor)));
327 version = combineVersions(
328 version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.ifPerformance)));
329 version = combineVersions(
330 version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.whilePerformance)));
331
332 return version;
333 }
334
validateExtensionOperandTypeInformation(const Extension::OperandTypeInformation & operandTypeInformation)335 Result<Version> validateExtensionOperandTypeInformation(
336 const Extension::OperandTypeInformation& operandTypeInformation) {
337 NN_VALIDATE_GT(operandTypeInformation.byteSize, 0u);
338 return Version::ANDROID_Q;
339 }
340
validateExtension(const Extension & extension)341 Result<Version> validateExtension(const Extension& extension) {
342 NN_VALIDATE(isValidExtensionName(extension.name));
343
344 // Verify all OperandTypeInformations have unique types.
345 std::vector<uint16_t> types;
346 types.reserve(extension.operandTypes.size());
347 std::transform(extension.operandTypes.begin(), extension.operandTypes.end(),
348 std::back_inserter(types),
349 [](const Extension::OperandTypeInformation& operandTypeInformation) {
350 return operandTypeInformation.type;
351 });
352 std::sort(types.begin(), types.end());
353 const auto iter = std::adjacent_find(types.begin(), types.end());
354 NN_VALIDATE(iter == types.end()) << "Extension has duplicate type " << *iter;
355
356 return combineVersions(Version::ANDROID_Q,
357 NN_TRY(validateVector(extension.operandTypes,
358 validateExtensionOperandTypeInformation)));
359 }
360
validateExtensions(const std::vector<Extension> & extensions)361 Result<Version> validateExtensions(const std::vector<Extension>& extensions) {
362 const auto version = NN_TRY(validateVector(extensions, validateExtension));
363
364 // Verify all extensions have unique names.
365 std::vector<std::reference_wrapper<const std::string>> names;
366 names.reserve(extensions.size());
367 std::transform(extensions.begin(), extensions.end(), std::back_inserter(names),
368 [](const Extension& extension) { return std::cref(extension.name); });
369 std::sort(names.begin(), names.end(), std::less<std::string>{});
370 const auto nameIter =
371 std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
372 NN_VALIDATE(nameIter == names.end())
373 << "Two or more extensions have the duplicate name " << nameIter->get();
374
375 return version;
376 }
377
378 // Forward declaration of subgraph validation function.
379 Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
380 std::optional<size_t> referencedIndex,
381 size_t operandValuesSize,
382 const std::vector<size_t>& poolSizes,
383 const std::vector<Model::Subgraph>& referenced,
384 std::vector<std::optional<Version>>* subgraphVersionCache);
385
validateOperandDataLocation(const Operand & operand,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,std::vector<std::optional<Version>> * subgraphVersionCache)386 Result<Version> validateOperandDataLocation(
387 const Operand& operand, size_t operandValuesSize, const std::vector<size_t>& poolSizes,
388 const std::vector<Model::Subgraph>& subgraphs,
389 std::vector<std::optional<Version>>* subgraphVersionCache) {
390 const DataLocation& location = operand.location;
391 NN_VALIDATE_EQ(location.padding, 0u)
392 << "DataLocation with a non-zero padding used in Model: " << location.padding;
393 switch (operand.lifetime) {
394 case Operand::LifeTime::CONSTANT_COPY:
395 NN_VALIDATE(location.pointer == kNullptrVariant)
396 << "CONSTANT_COPY with a non-null pointer";
397 NN_VALIDATE_EQ(location.poolIndex, 0u)
398 << "CONSTANT_COPY with a non-zero poolIndex " << location.poolIndex;
399 // Do the addition using uint64_t to avoid potential wrap-around problems.
400 NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length,
401 operandValuesSize)
402 << "OperandValue location out of range. Starts at " << location.offset
403 << ", length " << location.length << ", max " << operandValuesSize;
404 return Version::ANDROID_OC_MR1;
405 case Operand::LifeTime::CONSTANT_REFERENCE:
406 NN_VALIDATE_LT(location.poolIndex, poolSizes.size());
407 // Do the addition using uint64_t to avoid potential wrap-around problems.
408 NN_VALIDATE_LE(static_cast<uint64_t>(location.offset) + location.length,
409 poolSizes[location.poolIndex])
410 << "OperandValue location out of range. Starts at " << location.offset
411 << ", length " << location.length << ", max " << poolSizes[location.poolIndex];
412 return Version::ANDROID_OC_MR1;
413 case Operand::LifeTime::TEMPORARY_VARIABLE:
414 case Operand::LifeTime::SUBGRAPH_INPUT:
415 case Operand::LifeTime::SUBGRAPH_OUTPUT:
416 case Operand::LifeTime::NO_VALUE:
417 NN_VALIDATE(location.pointer == kNullptrVariant)
418 << "Unexpected pointer value for operand of lifetime " << operand.lifetime;
419 NN_VALIDATE_EQ(location.poolIndex, 0u)
420 << "Unexpected poolIndex " << location.poolIndex << " for operand of lifetime "
421 << operand.lifetime;
422 NN_VALIDATE_EQ(location.offset, 0u) << "Unexpected offset " << location.offset
423 << " for operand of lifetime " << operand.lifetime;
424 NN_VALIDATE_EQ(location.length, 0u) << "Unexpected length " << location.length
425 << " for operand of lifetime " << operand.lifetime;
426 return Version::ANDROID_OC_MR1;
427 case Operand::LifeTime::SUBGRAPH: {
428 NN_VALIDATE(location.pointer == kNullptrVariant) << "SUBGRAPH with a non-null pointer";
429 NN_VALIDATE_EQ(location.poolIndex, 0u)
430 << "SUBGRAPH with a non-zero poolIndex " << location.poolIndex;
431 NN_VALIDATE_LT(location.offset, subgraphs.size())
432 << "Subgraph index out of range: " << location.offset
433 << " >= " << subgraphs.size();
434 NN_VALIDATE_EQ(location.length, 0u)
435 << "SUBGRAPH with a non-zero length " << location.length;
436 const auto version = NN_TRY(validateModelSubgraph(
437 subgraphs[location.offset], location.offset, operandValuesSize, poolSizes,
438 subgraphs, subgraphVersionCache));
439 return combineVersions(version, Version::ANDROID_R);
440 }
441 case Operand::LifeTime::POINTER: {
442 const bool nonNull =
443 std::visit([](auto* ptr) { return ptr != nullptr; }, location.pointer);
444 NN_VALIDATE(nonNull) << "POINTER with a null pointer";
445 NN_VALIDATE_EQ(location.poolIndex, 0u)
446 << "POINTER with a non-zero poolIndex " << location.poolIndex;
447 NN_VALIDATE_EQ(location.offset, 0u)
448 << "POINTER with a non-zero offset " << location.offset;
449 return Version::ANDROID_OC_MR1;
450 }
451 }
452 NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime;
453 }
454
validateOperandDimensions(const Operand & operand)455 Result<Version> validateOperandDimensions(const Operand& operand) {
456 switch (operand.type) {
457 case OperandType::FLOAT32:
458 case OperandType::INT32:
459 case OperandType::UINT32:
460 case OperandType::BOOL:
461 case OperandType::FLOAT16:
462 case OperandType::SUBGRAPH:
463 case OperandType::OEM:
464 NN_VALIDATE(operand.dimensions.empty())
465 << "Scalar data has dimensions of rank " << operand.dimensions.size();
466 return Version::ANDROID_OC_MR1;
467 case OperandType::TENSOR_FLOAT32:
468 case OperandType::TENSOR_INT32:
469 case OperandType::TENSOR_QUANT8_ASYMM:
470 case OperandType::TENSOR_QUANT16_SYMM:
471 case OperandType::TENSOR_FLOAT16:
472 case OperandType::TENSOR_BOOL8:
473 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
474 case OperandType::TENSOR_QUANT16_ASYMM:
475 case OperandType::TENSOR_QUANT8_SYMM:
476 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
477 case OperandType::TENSOR_OEM_BYTE: {
478 if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
479 operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
480 operand.lifetime == Operand::LifeTime::POINTER) {
481 NN_VALIDATE(!operand.dimensions.empty())
482 << "Tensor has lifetime of " << operand.lifetime
483 << " but dimensions of rank 0";
484 const auto size = getNonExtensionSize(operand);
485 NN_VALIDATE(size.has_value()) << "Tensor dimensions overflow";
486 NN_VALIDATE_NE(size.value(), 0u) << "Tensor has at least one unknown dimension";
487 }
488 // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before
489 // Android Q?
490 if (operand.dimensions.empty()) {
491 // Unspecified rank was added in Android Q.
492 return Version::ANDROID_Q;
493 }
494 return Version::ANDROID_OC_MR1;
495 }
496 }
497 if (isExtension(operand.type)) {
498 // Extension types were added in Android Q.
499 return Version::ANDROID_Q;
500 }
501 NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
502 }
503
validateOperandScale(const Operand & operand)504 Result<Version> validateOperandScale(const Operand& operand) {
505 switch (operand.type) {
506 case OperandType::FLOAT32:
507 case OperandType::INT32:
508 case OperandType::UINT32:
509 case OperandType::TENSOR_FLOAT32:
510 case OperandType::BOOL:
511 case OperandType::TENSOR_FLOAT16:
512 case OperandType::TENSOR_BOOL8:
513 case OperandType::FLOAT16:
514 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
515 case OperandType::SUBGRAPH:
516 NN_VALIDATE_EQ(operand.scale, 0.0f)
517 << "Operand of type " << operand.type << " with a non-zero scale ("
518 << operand.scale << ")";
519 return Version::ANDROID_OC_MR1;
520 case OperandType::TENSOR_INT32:
521 // TENSOR_INT32 may be used with or without scale, depending on the operation.
522 // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale.
523 NN_VALIDATE_GE(operand.scale, 0.0f)
524 << "Operand of type " << operand.type << " with a negative scale";
525 return Version::ANDROID_OC_MR1;
526 case OperandType::TENSOR_QUANT8_ASYMM:
527 case OperandType::TENSOR_QUANT16_SYMM:
528 case OperandType::TENSOR_QUANT16_ASYMM:
529 case OperandType::TENSOR_QUANT8_SYMM:
530 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
531 NN_VALIDATE_GT(operand.scale, 0.0f)
532 << "Operand of type " << operand.type << " with a non-positive scale";
533 return Version::ANDROID_OC_MR1;
534 case OperandType::OEM:
535 case OperandType::TENSOR_OEM_BYTE:
536 // No validation for OEM types.
537 return Version::ANDROID_OC_MR1;
538 }
539 if (isExtension(operand.type)) {
540 NN_VALIDATE_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type
541 << " with a non-zero scale (" << operand.scale << ")";
542 return Version::ANDROID_Q;
543 }
544 NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
545 }
546
validateOperandZeroPoint(const Operand & operand)547 Result<Version> validateOperandZeroPoint(const Operand& operand) {
548 switch (operand.type) {
549 case OperandType::FLOAT32:
550 case OperandType::INT32:
551 case OperandType::UINT32:
552 case OperandType::TENSOR_FLOAT32:
553 case OperandType::TENSOR_INT32:
554 case OperandType::BOOL:
555 case OperandType::TENSOR_FLOAT16:
556 case OperandType::TENSOR_BOOL8:
557 case OperandType::FLOAT16:
558 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
559 case OperandType::TENSOR_QUANT8_SYMM:
560 case OperandType::SUBGRAPH:
561 NN_VALIDATE_EQ(operand.zeroPoint, 0)
562 << "Operand of type " << operand.type << " with a non-zero zeroPoint "
563 << operand.zeroPoint;
564 return Version::ANDROID_OC_MR1;
565 case OperandType::TENSOR_QUANT8_ASYMM:
566 NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 255)
567 << "Operand of type " << operand.type << " with an invalid zeroPoint "
568 << operand.zeroPoint << ", must be in range [0, 255]";
569 return Version::ANDROID_OC_MR1;
570 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
571 NN_VALIDATE(operand.zeroPoint >= -128 && operand.zeroPoint <= 127)
572 << "Operand of type " << operand.type << " with an invalid zeroPoint "
573 << operand.zeroPoint << ", must be in range [-128, 127]";
574 return Version::ANDROID_OC_MR1;
575 case OperandType::TENSOR_QUANT16_ASYMM:
576 NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 65535)
577 << "Operand of type " << operand.type << " with an invalid zeroPoint "
578 << operand.zeroPoint << ", must be in range [0, 65535]";
579 return Version::ANDROID_OC_MR1;
580 case OperandType::TENSOR_QUANT16_SYMM:
581 NN_VALIDATE_EQ(operand.zeroPoint, 0)
582 << "Operand of type " << operand.type << " with a non-zero zeroPoint "
583 << operand.zeroPoint;
584 return Version::ANDROID_OC_MR1;
585 case OperandType::OEM:
586 case OperandType::TENSOR_OEM_BYTE:
587 // No validation for OEM types.
588 return Version::ANDROID_OC_MR1;
589 }
590 if (isExtension(operand.type)) {
591 NN_VALIDATE_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type
592 << " with a non-zero zeroPoint " << operand.zeroPoint;
593 return Version::ANDROID_Q;
594 }
595 NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
596 }
597
validateOperandExtraParams(const Operand & operand)598 Result<Version> validateOperandExtraParams(const Operand& operand) {
599 switch (operand.type) {
600 case OperandType::FLOAT32:
601 case OperandType::INT32:
602 case OperandType::UINT32:
603 case OperandType::TENSOR_FLOAT32:
604 case OperandType::TENSOR_INT32:
605 case OperandType::TENSOR_QUANT8_ASYMM:
606 case OperandType::BOOL:
607 case OperandType::TENSOR_QUANT16_SYMM:
608 case OperandType::TENSOR_FLOAT16:
609 case OperandType::TENSOR_BOOL8:
610 case OperandType::FLOAT16:
611 case OperandType::TENSOR_QUANT16_ASYMM:
612 case OperandType::TENSOR_QUANT8_SYMM:
613 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
614 case OperandType::SUBGRAPH:
615 NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams))
616 << "Operand of type " << operand.type
617 << " has extraParams when there must be none";
618 return Version::ANDROID_OC_MR1;
619 case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
620 NN_VALIDATE(
621 std::holds_alternative<Operand::SymmPerChannelQuantParams>(operand.extraParams))
622 << "Operand of type " << operand.type
623 << " without a Channel Quantization params";
624 const auto& channelQuant =
625 std::get<Operand::SymmPerChannelQuantParams>(operand.extraParams);
626
627 const size_t count = operand.dimensions.size();
628 NN_VALIDATE_LT(channelQuant.channelDim, count)
629 << "Operand of type " << operand.type
630 << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
631 << ", must be valid dimension index in range [0, " << count << ")";
632 const uint32_t expected = operand.dimensions[channelQuant.channelDim];
633 NN_VALIDATE_EQ(channelQuant.scales.size(), expected)
634 << "Operand of type " << operand.type << " with a wrong-sized scales, expected "
635 << expected << " was " << channelQuant.scales.size();
636 NN_VALIDATE_NE(expected, 0u)
637 << "Operand of type " << operand.type << " channel dimension "
638 << channelQuant.channelDim << " is underspecified (can't be 0)";
639 for (uint32_t i = 0; i < expected; ++i) {
640 NN_VALIDATE_GT(channelQuant.scales[i], 0.0f)
641 << "Operand of type " << operand.type
642 << " with a non-positive value in scales[" << i
643 << "]=" << channelQuant.scales[i];
644 }
645 return Version::ANDROID_Q;
646 }
647 case OperandType::OEM:
648 case OperandType::TENSOR_OEM_BYTE:
649 // No validation for OEM types.
650 return Version::ANDROID_OC_MR1;
651 }
652 if (isExtension(operand.type)) {
653 NN_VALIDATE(std::holds_alternative<Operand::NoParams>(operand.extraParams) ||
654 std::holds_alternative<Operand::ExtensionParams>(operand.extraParams))
655 << "Extension operand of type " << operand.type
656 << " must not have SymmPerChannelQuant extraParams";
657 return Version::ANDROID_OC_MR1;
658 }
659 NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type;
660 }
661
validateOperand(const Operand & operand,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,std::vector<std::optional<Version>> * subgraphVersionCache)662 Result<Version> validateOperand(const Operand& operand, size_t operandValuesSize,
663 const std::vector<size_t>& poolSizes,
664 const std::vector<Model::Subgraph>& subgraphs,
665 std::vector<std::optional<Version>>* subgraphVersionCache) {
666 auto version = NN_TRY(validateOperandType(operand.type));
667 version = combineVersions(version, NN_TRY(validateOperandLifeTime(operand)));
668 version = combineVersions(version, NN_TRY(validateOperandDimensions(operand)));
669 version = combineVersions(version, NN_TRY(validateOperandScale(operand)));
670 version = combineVersions(version, NN_TRY(validateOperandZeroPoint(operand)));
671 version = combineVersions(version, NN_TRY(validateOperandExtraParams(operand)));
672 version = combineVersions(
673 version, NN_TRY(validateOperandDataLocation(operand, operandValuesSize, poolSizes,
674 subgraphs, subgraphVersionCache)));
675
676 // For constants, validate that the length is as expected. The other lifetimes
677 // expect the length to be 0. Don't validate for OEM types.
678 if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE ||
679 operand.lifetime == Operand::LifeTime::CONSTANT_COPY ||
680 operand.lifetime == Operand::LifeTime::POINTER) {
681 if (!isExtension(operand.type) && operand.type != OperandType::OEM &&
682 operand.type != OperandType::TENSOR_OEM_BYTE) {
683 const auto expectedLength = getNonExtensionSize(operand).value();
684 NN_VALIDATE_EQ(operand.location.length, expectedLength)
685 << "For operand " << operand.type << " expected a size of " << expectedLength
686 << " but got " << operand.location.length;
687 }
688 }
689
690 return version;
691 }
692
validateOperands(const std::vector<Operand> & operands,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,std::vector<std::optional<Version>> * subgraphVersionCache)693 Result<std::vector<Version>> validateOperands(
694 const std::vector<Operand>& operands, size_t operandValuesSize,
695 const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs,
696 std::vector<std::optional<Version>>* subgraphVersionCache) {
697 std::vector<Version> versions;
698 versions.reserve(operands.size());
699 for (size_t i = 0; i < operands.size(); ++i) {
700 auto result = validateOperand(operands[i], operandValuesSize, poolSizes, subgraphs,
701 subgraphVersionCache);
702 if (!result.has_value()) {
703 return error() << std::move(result).error() << " for operand " << i;
704 }
705 versions.push_back(result.value());
706 }
707 return versions;
708 }
709
// Forward declaration, so that validateOperations() below can call this function before its
// definition has been seen.
Result<Version> validateOperationIncludingOperandVersions(
        const Operation& operation, const std::vector<Operand>& operands,
        const std::vector<Version>& operandVersions, const std::vector<Model::Subgraph>& subgraphs);
714
validateOperations(const std::vector<Operation> & operations,const std::vector<Operand> & operands,const std::vector<Version> & operandVersions,const std::vector<Model::Subgraph> & subgraphs)715 Result<Version> validateOperations(const std::vector<Operation>& operations,
716 const std::vector<Operand>& operands,
717 const std::vector<Version>& operandVersions,
718 const std::vector<Model::Subgraph>& subgraphs) {
719 auto version = Version::ANDROID_OC_MR1;
720 for (size_t i = 0; i < operations.size(); ++i) {
721 auto result = validateOperationIncludingOperandVersions(operations[i], operands,
722 operandVersions, subgraphs);
723 if (!result.has_value()) {
724 return error() << std::move(result).error() << " for operation " << i;
725 }
726 version = combineVersions(version, result.value());
727 }
728 return version;
729 }
730
// Validates that every file descriptor held in `handle.fds` reports ok(). A handle with no file
// descriptors passes, because std::all_of is true for an empty range.
//
// Returns Version::ANDROID_OC_MR1 on success.
Result<Version> validateHandle(const Handle& handle) {
    NN_VALIDATE(std::all_of(handle.fds.begin(), handle.fds.end(),
                            [](const base::unique_fd& fd) { return fd.ok(); }));
    return Version::ANDROID_OC_MR1;
}
736
// Validates a shared handle: it must be non-null, and the Handle it points to must pass
// validateHandle(). Returns whatever validateHandle() returns for the pointee.
Result<Version> validateSharedHandle(const SharedHandle& handle) {
    NN_VALIDATE(handle != nullptr);
    return validateHandle(*handle);
}
741
// Validates ashmem-backed memory: the file descriptor must report ok() and the size must be
// non-zero. Returns Version::ANDROID_OC_MR1 on success.
Result<Version> validateMemory(const Memory::Ashmem& memory) {
    NN_VALIDATE(memory.fd.ok());
    NN_VALIDATE_NE(memory.size, 0u);
    return Version::ANDROID_OC_MR1;
}
747
// Validates fd-backed memory: the file descriptor must report ok(), the size must be non-zero,
// and `prot` must not contain any bit other than PROT_READ and PROT_WRITE.
// Returns Version::ANDROID_OC_MR1 on success.
Result<Version> validateMemory(const Memory::Fd& memory) {
    NN_VALIDATE(memory.fd.ok());
    NN_VALIDATE_NE(memory.size, 0u);

    // `prot` is allowed to be PROT_NONE (which has a value of 0), PROT_READ, PROT_WRITE, or the
    // bitwise OR of PROT_READ and PROT_WRITE. If any other bit is set, the `prot` field is
    // invalid.
    constexpr int kAllowedBits = PROT_READ | PROT_WRITE;
    NN_VALIDATE_EQ(memory.prot & ~kAllowedBits, 0);

    return Version::ANDROID_OC_MR1;
}
759
// Validates hardware-buffer-backed memory: the wrapped handle must be non-null.
// Returns Version::ANDROID_Q on success.
Result<Version> validateMemory(const Memory::HardwareBuffer& memory) {
    NN_VALIDATE(memory.handle.get() != nullptr);
    return Version::ANDROID_Q;
}
764
// Validates memory of unknown kind by validating its handle with validateHandle().
// The version returned by validateHandle() is discarded; on success this always returns
// Version::ANDROID_Q.
Result<Version> validateMemory(const Memory::Unknown& memory) {
    NN_TRY(validateHandle(memory.handle));
    return Version::ANDROID_Q;
}
769
validateSharedMemory(const SharedMemory & memory)770 Result<Version> validateSharedMemory(const SharedMemory& memory) {
771 NN_VALIDATE(memory != nullptr);
772 return std::visit([](const auto& x) { return validateMemory(x); }, memory->handle);
773 }
774
validateModelSubgraphInputOutputs(const std::vector<uint32_t> & indexes,const std::vector<Operand> & operands,Operand::LifeTime lifetime)775 Result<void> validateModelSubgraphInputOutputs(const std::vector<uint32_t>& indexes,
776 const std::vector<Operand>& operands,
777 Operand::LifeTime lifetime) {
778 const size_t operandCount = operands.size();
779 for (uint32_t i : indexes) {
780 NN_VALIDATE_LT(i, operandCount)
781 << "Model " << lifetime << " input or output index out of range: " << i << "/"
782 << operandCount;
783 const Operand& operand = operands[i];
784 NN_VALIDATE_EQ(operand.lifetime, lifetime)
785 << "Model " << lifetime << " operand " << i << " has lifetime of "
786 << operand.lifetime << " instead of the expected " << lifetime;
787 }
788
789 std::vector<uint32_t> sortedIndexes = indexes;
790 std::sort(sortedIndexes.begin(), sortedIndexes.end());
791 const auto iter = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end());
792 NN_VALIDATE(iter == sortedIndexes.end())
793 << "Model input or output occurs multiple times: " << *iter;
794
795 for (size_t i = 0; i < operands.size(); ++i) {
796 if (operands[i].lifetime == lifetime) {
797 const auto containsIndex = [&sortedIndexes](size_t index) {
798 return binary_search(sortedIndexes.begin(), sortedIndexes.end(), index);
799 };
800 NN_VALIDATE(containsIndex(i))
801 << "Operand " << i << " marked as " << lifetime
802 << " but is not included in Model input or output indexes";
803 }
804 }
805
806 return {};
807 }
808
validateExecutionOrder(const Model::Subgraph & subgraph)809 Result<void> validateExecutionOrder(const Model::Subgraph& subgraph) {
810 // Either the operand has a known value before model execution begins, or we've seen a writer
811 // for this operand while walking operands in execution order. Initialize to known operands.
812 std::vector<bool> operandValueKnown;
813 operandValueKnown.reserve(subgraph.operands.size());
814 std::transform(subgraph.operands.begin(), subgraph.operands.end(),
815 std::back_inserter(operandValueKnown), [](const Operand& operand) {
816 return operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE &&
817 operand.lifetime != Operand::LifeTime::SUBGRAPH_OUTPUT;
818 });
819
820 // Validate that operations are sorted into execution order.
821 //
822 // If there is a cycle in the graph, the operations will not
823 // appear to be sorted into execution order: Some operation will
824 // have an input for which operandValueKnown[] is false.
825 for (size_t i = 0; i < subgraph.operations.size(); ++i) {
826 const auto& operation = subgraph.operations[i];
827
828 for (size_t j = 0; j < operation.inputs.size(); ++j) {
829 const uint32_t k = operation.inputs[j];
830 NN_VALIDATE(operandValueKnown[k]) << "Operation " << i << " input " << j << " (operand "
831 << k << ") is read before it is written";
832 }
833
834 for (size_t j = 0; j < operation.outputs.size(); ++j) {
835 const uint32_t k = operation.outputs[j];
836 // Assuming validateOperations() has not returned an error, we know that this output is
837 // TEMPORARY_VARIABLE or MODEL_OUTPUT, and so the only way operandValueKnown[k] can be
838 // true is if we've already seen a writer for this operand.
839 NN_VALIDATE(!operandValueKnown[k]) << "Operation " << i << " output " << j
840 << " (operand " << k << ") has already been written";
841 operandValueKnown[k] = true;
842 }
843 }
844
845 // Verify all operands are written.
846 for (size_t i = 0; i < subgraph.operands.size(); ++i) {
847 NN_VALIDATE(operandValueKnown[i]) << "Operand " << i << " is never written";
848 }
849
850 // TODO(b/77871786): verify that every operation has at least one output operand that is read?
851
852 return {};
853 }
854
855 // Validate a subgraph, ensuring all subgraphs it depends on are also validated.
856 //
857 // `referencedIndex` is empty if the subgraph being validated is the main subgraph, otherwise it is
858 // the index of the referenced subgraph being validated.
859 //
860 // referenced[i] and (*subgraphVersionCache)[i] correspond to the same subgraph, and therefore
861 // `referenced` and `subgraphVersionCache` must have the same length.
validateModelSubgraph(const Model::Subgraph & subgraph,std::optional<size_t> referencedIndex,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & referenced,std::vector<std::optional<Version>> * subgraphVersionCache)862 Result<Version> validateModelSubgraph(const Model::Subgraph& subgraph,
863 std::optional<size_t> referencedIndex,
864 size_t operandValuesSize,
865 const std::vector<size_t>& poolSizes,
866 const std::vector<Model::Subgraph>& referenced,
867 std::vector<std::optional<Version>>* subgraphVersionCache) {
868 CHECK(subgraphVersionCache != nullptr);
869 CHECK_EQ(referenced.size(), subgraphVersionCache->size());
870
871 // Quickly return if the current subgraph has already been checked for its version.
872 if (referencedIndex.has_value()) {
873 if (auto version = subgraphVersionCache->at(*referencedIndex)) {
874 return *version;
875 }
876 }
877
878 NN_VALIDATE(!subgraph.operands.empty());
879 NN_VALIDATE(!subgraph.operations.empty());
880 // TODO(b/173780642): Clarify whether subgraphs with no inputs or outputs are valid.
881 // NN_VALIDATE(!subgraph.inputIndexes.empty());
882 // NN_VALIDATE(!subgraph.outputIndexes.empty());
883
884 const auto operandVersions = NN_TRY(validateOperands(
885 subgraph.operands, operandValuesSize, poolSizes, referenced, subgraphVersionCache));
886 const auto operationsVersion = NN_TRY(validateOperations(subgraph.operations, subgraph.operands,
887 operandVersions, referenced));
888
889 // Accumulate the versions from all operands and operations.
890 const auto version = std::accumulate(operandVersions.begin(), operandVersions.end(),
891 operationsVersion, combineVersions);
892
893 NN_TRY(validateModelSubgraphInputOutputs(subgraph.inputIndexes, subgraph.operands,
894 Operand::LifeTime::SUBGRAPH_INPUT));
895 NN_TRY(validateModelSubgraphInputOutputs(subgraph.outputIndexes, subgraph.operands,
896 Operand::LifeTime::SUBGRAPH_OUTPUT));
897
898 NN_TRY(validateExecutionOrder(subgraph));
899
900 // Mark the current subgraph as having already been validated so the caller can quickly return
901 // if this subgraph is checked again.
902 if (referencedIndex.has_value()) {
903 subgraphVersionCache->at(*referencedIndex) = version;
904 }
905 return version;
906 }
907
validateModelExtensionNamesAndPrefixes(const std::vector<Model::ExtensionNameAndPrefix> & extensionNamesAndPrefixes)908 Result<Version> validateModelExtensionNamesAndPrefixes(
909 const std::vector<Model::ExtensionNameAndPrefix>& extensionNamesAndPrefixes) {
910 for (const auto& extensionNameAndPrefix : extensionNamesAndPrefixes) {
911 NN_VALIDATE(isValidExtensionName(extensionNameAndPrefix.name));
912 }
913
914 std::vector<std::reference_wrapper<const std::string>> names;
915 names.reserve(extensionNamesAndPrefixes.size());
916 std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
917 std::back_inserter(names),
918 [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
919 return std::cref(extensionNameAndPrefix.name);
920 });
921 std::sort(names.begin(), names.end(), std::less<std::string>{});
922 const auto nameIter =
923 std::adjacent_find(names.begin(), names.end(), std::equal_to<std::string>{});
924 NN_VALIDATE(nameIter == names.end())
925 << "ExtensionNamesAndPrefixes has duplicate name " << nameIter->get();
926
927 std::vector<uint16_t> types;
928 types.reserve(extensionNamesAndPrefixes.size());
929 std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(),
930 std::back_inserter(types),
931 [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) {
932 return extensionNameAndPrefix.prefix;
933 });
934 std::sort(types.begin(), types.end());
935 const auto typeIter = std::adjacent_find(types.begin(), types.end());
936 NN_VALIDATE(typeIter == types.end())
937 << "ExtensionNamesAndPrefixes has duplicate type " << *typeIter;
938
939 const bool hasExtensions = !extensionNamesAndPrefixes.empty();
940 return hasExtensions ? Version::ANDROID_Q : Version::ANDROID_OC_MR1;
941 }
942
943 // Makes sure the model does not contain subgraph reference cycles.
944 //
945 // This function verifies that referencedSubgraphs[subgraphIndex] and any subgraphs it refences do
946 // not contain any reference cycles. `path` is used to keep track of which referenced subgraphs have
947 // already been visited in the current recursive reference path. `verified` is a cache to keep track
948 // of which referenced subgraphs have already been verified not to form reference cycles.
949 //
950 // referencedSubgraphs[i], (*path)[i], and (*verified)[i] all correspond to the same subgraph, and
951 // therefore `referencedSubgraphs`, `path`, and `verified` must all have the same length.
checkNoReferenceCycles(const std::vector<Model::Subgraph> & referencedSubgraphs,uint32_t subgraphIndex,std::vector<bool> * path,std::vector<bool> * verified)952 Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs,
953 uint32_t subgraphIndex, std::vector<bool>* path,
954 std::vector<bool>* verified) {
955 CHECK(path != nullptr);
956 CHECK(verified != nullptr);
957 CHECK_EQ(referencedSubgraphs.size(), path->size());
958 CHECK_EQ(referencedSubgraphs.size(), verified->size());
959 const auto& subgraph = referencedSubgraphs.at(subgraphIndex);
960
961 // Quickly return if the current subgraph has already been verified to have no reference cycles.
962 if ((*verified)[subgraphIndex]) {
963 return {};
964 }
965
966 // Add the current subgraph to the path (making sure that it is not already part of the path),
967 // and verify that all subgraphs this subgraph references do not contain cycles. The current
968 // subgraph is removed from the path only after all subgraphs this subgraph references have been
969 // checked.
970 NN_VALIDATE((*path)[subgraphIndex] == false) << "Model contains a circular subgraph reference";
971 (*path)[subgraphIndex] = true;
972 for (const Operand& operand : subgraph.operands) {
973 if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
974 const uint32_t refSubgraphIndex = operand.location.offset;
975 NN_TRY(checkNoReferenceCycles(referencedSubgraphs, refSubgraphIndex, path, verified));
976 }
977 }
978 (*path)[subgraphIndex] = false;
979
980 // Mark the current subgraph as having already been verified so the caller can quickly return if
981 // this subgraph is checked again.
982 (*verified)[subgraphIndex] = true;
983 return {};
984 }
985
checkNoReferenceCycles(const std::vector<Model::Subgraph> & referencedSubgraphs)986 Result<void> checkNoReferenceCycles(const std::vector<Model::Subgraph>& referencedSubgraphs) {
987 const size_t count = referencedSubgraphs.size();
988 std::vector<bool> path(count);
989 std::vector<bool> verified(count);
990 for (size_t i = 0; i < count; ++i) {
991 NN_TRY(checkNoReferenceCycles(referencedSubgraphs, i, &path, &verified));
992 }
993 return {};
994 }
995
validateModel(const Model & model)996 Result<Version> validateModel(const Model& model) {
997 auto version = NN_TRY(validateVector(model.pools, validateSharedMemory));
998 version = combineVersions(
999 version, NN_TRY(validateModelExtensionNamesAndPrefixes(model.extensionNameToPrefix)));
1000
1001 // Ignore relaxComputationFloat32toFloat16 version because in the worst case it makes the
1002 // execution stricter.
1003
1004 // Referenced models were introduced in Android R.
1005 const bool hasReferencedModels = !model.referenced.empty();
1006 const auto referenceModelVersion =
1007 hasReferencedModels ? Version::ANDROID_R : Version::ANDROID_OC_MR1;
1008 version = combineVersions(version, referenceModelVersion);
1009
1010 // Ensure that there are no cycles formed by the subgraphs.
1011 NN_TRY(checkNoReferenceCycles(model.referenced));
1012
1013 // Get memory sizes.
1014 const auto [operandValuesSize, poolSizes] = getMemorySizes(model);
1015
1016 // Validate referenced subgraphs.
1017 auto subgraphVersionCache = std::vector<std::optional<Version>>(model.referenced.size());
1018 for (size_t referencedIndex = 0; referencedIndex < model.referenced.size(); ++referencedIndex) {
1019 const auto& subgraph = model.referenced[referencedIndex];
1020 const auto subgraphVersion =
1021 NN_TRY(validateModelSubgraph(subgraph, referencedIndex, operandValuesSize,
1022 poolSizes, model.referenced, &subgraphVersionCache));
1023 version = combineVersions(version, subgraphVersion);
1024 }
1025
1026 // Validate main subgraph.
1027 const auto subgraphVersion =
1028 NN_TRY(validateModelSubgraph(model.main, std::nullopt, operandValuesSize, poolSizes,
1029 model.referenced, &subgraphVersionCache));
1030 version = combineVersions(version, subgraphVersion);
1031
1032 return version;
1033 }
1034
validateBufferDesc(const BufferDesc & bufferDesc)1035 Result<Version> validateBufferDesc(const BufferDesc& bufferDesc) {
1036 // An empty BufferDesc is the default value, so it is implicitly valid for all versions.
1037 return bufferDesc.dimensions.empty() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R;
1038 }
1039
// Validates a buffer role: `probability` must be greater than 0.0 and at most 1.0.
// Returns Version::ANDROID_R on success.
Result<Version> validateBufferRole(const BufferRole& bufferRole) {
    NN_VALIDATE_GT(bufferRole.probability, 0.0f);
    NN_VALIDATE_LE(bufferRole.probability, 1.0f);
    return Version::ANDROID_R;
}
1045
validateRequestArgument(const Request::Argument & requestArgument,const std::vector<size_t> & memorySizes,bool isOutput)1046 Result<Version> validateRequestArgument(const Request::Argument& requestArgument,
1047 const std::vector<size_t>& memorySizes, bool isOutput) {
1048 const auto lifetime = requestArgument.lifetime;
1049 const auto& location = requestArgument.location;
1050 const auto& dimensions = requestArgument.dimensions;
1051
1052 switch (lifetime) {
1053 case Request::Argument::LifeTime::POOL: {
1054 NN_VALIDATE(location.pointer == kNullptrVariant);
1055 NN_VALIDATE_LT(location.poolIndex, memorySizes.size());
1056 // Do the addition using uint64_t to avoid potential wrap-around problems.
1057 const auto lastPosition =
1058 static_cast<uint64_t>(location.offset) + location.length + location.padding;
1059 const auto memorySize = memorySizes[location.poolIndex];
1060 NN_VALIDATE_LE(lastPosition, memorySize);
1061 if (memorySize > 0) {
1062 // Must specify a positive length if the memory pool has a known size.
1063 NN_VALIDATE_GT(location.length, 0u);
1064 }
1065 return Version::ANDROID_OC_MR1;
1066 }
1067 case Request::Argument::LifeTime::NO_VALUE:
1068 NN_VALIDATE(location.pointer == kNullptrVariant);
1069 NN_VALIDATE_EQ(location.poolIndex, 0u);
1070 NN_VALIDATE_EQ(location.offset, 0u);
1071 NN_VALIDATE_EQ(location.length, 0u);
1072 NN_VALIDATE_EQ(location.padding, 0u);
1073 NN_VALIDATE(dimensions.empty());
1074 return Version::ANDROID_OC_MR1;
1075 case Request::Argument::LifeTime::POINTER: {
1076 const bool isNullptr =
1077 std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer);
1078 NN_VALIDATE(!isNullptr);
1079 NN_VALIDATE_EQ(location.poolIndex, 0u);
1080 NN_VALIDATE_EQ(location.offset, 0u);
1081 NN_VALIDATE_NE(location.length, 0u);
1082 if (isOutput) {
1083 NN_VALIDATE(std::holds_alternative<void*>(location.pointer));
1084 }
1085 return Version::ANDROID_OC_MR1;
1086 }
1087 }
1088 NN_VALIDATE_FAIL() << "Invalid Request::Argument::LifeTime " << lifetime;
1089 }
1090
// Validates one memory pool of a request, according to which alternative the pool holds:
//   - Request::MemoryDomainToken: must differ from kInvalidMemoryDomainToken; Version::ANDROID_R.
//   - SharedBuffer: must be non-null; Version::ANDROID_R.
//   - otherwise (SharedMemory): delegated to validateSharedMemory().
Result<Version> validateRequestMemoryPool(const Request::MemoryPool& memoryPool) {
    if (std::holds_alternative<Request::MemoryDomainToken>(memoryPool)) {
        NN_VALIDATE(std::get<Request::MemoryDomainToken>(memoryPool) != kInvalidMemoryDomainToken);
        return Version::ANDROID_R;
    }
    if (std::holds_alternative<SharedBuffer>(memoryPool)) {
        NN_VALIDATE(std::get<SharedBuffer>(memoryPool) != nullptr);
        return Version::ANDROID_R;
    }
    // Neither of the alternatives above is held, so the pool is read as SharedMemory.
    return validateSharedMemory(std::get<SharedMemory>(memoryPool));
}
1102
validateRequest(const Request & request)1103 Result<Version> validateRequest(const Request& request) {
1104 auto version = NN_TRY(validateVector(request.pools, validateRequestMemoryPool));
1105
1106 // Get memory sizes. For IBuffer or MemoryDomainToken types, set size to 0.
1107 std::vector<size_t> memorySizes;
1108 memorySizes.reserve(request.pools.size());
1109 std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(memorySizes),
1110 [](const Request::MemoryPool& memoryPool) {
1111 const auto* memory = std::get_if<SharedMemory>(&memoryPool);
1112 return memory != nullptr ? getSize(*memory) : 0;
1113 });
1114
1115 for (size_t i = 0; i < request.inputs.size(); ++i) {
1116 const auto& input = request.inputs[i];
1117 auto result = validateRequestArgument(input, memorySizes, /*isOutput=*/false);
1118 if (!result.has_value()) {
1119 return error() << std::move(result).error() << " for input RequestArgument " << i;
1120 }
1121 version = combineVersions(version, result.value());
1122 }
1123 for (size_t i = 0; i < request.outputs.size(); ++i) {
1124 const auto& output = request.outputs[i];
1125 auto result = validateRequestArgument(output, memorySizes, /*isOutput=*/true);
1126 if (!result.has_value()) {
1127 return error() << std::move(result).error() << " for output RequestArgument " << i;
1128 }
1129 version = combineVersions(version, result.value());
1130 }
1131
1132 return version;
1133 }
1134
validateOptionalTimePoint(const OptionalTimePoint & optionalTimePoint)1135 Result<Version> validateOptionalTimePoint(const OptionalTimePoint& optionalTimePoint) {
1136 if (optionalTimePoint.has_value()) {
1137 NN_VALIDATE_GE(optionalTimePoint->time_since_epoch().count(), 0);
1138 }
1139 // An omitted time point is the default value, so it is implicitly valid for all versions.
1140 return !optionalTimePoint.has_value() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R;
1141 }
1142
validateOptionalTimeoutDuration(const OptionalDuration & optionalTimeoutDuration)1143 Result<Version> validateOptionalTimeoutDuration(const OptionalDuration& optionalTimeoutDuration) {
1144 if (optionalTimeoutDuration.has_value()) {
1145 NN_VALIDATE_GE(optionalTimeoutDuration->count(), 0);
1146 }
1147 // An omitted duration is the default value, so it is implicitly valid for all versions.
1148 return !optionalTimeoutDuration.has_value() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R;
1149 }
1150
validateCacheToken(const CacheToken & cacheToken)1151 Result<Version> validateCacheToken(const CacheToken& cacheToken) {
1152 // A CacheToken of 0 is the default value, so it is implicitly valid for all versions.
1153 constexpr auto kDefaultCacheToken = CacheToken{};
1154 return cacheToken == kDefaultCacheToken ? Version::ANDROID_OC_MR1 : Version::ANDROID_Q;
1155 }
1156
validateSyncFence(const SyncFence & syncFence)1157 Result<Version> validateSyncFence(const SyncFence& syncFence) {
1158 // The absence of a sync fence is implicitly valid for all versions.
1159 if (!syncFence.hasFd()) {
1160 return Version::ANDROID_OC_MR1;
1161 }
1162 NN_VALIDATE_GE(syncFence.getFd(), 0);
1163 return Version::ANDROID_R;
1164 }
1165
validateRequestArgumentsForModel(const std::vector<Request::Argument> & requestArguments,const std::vector<uint32_t> & operandIndexes,const std::vector<Operand> & operands,bool isOutput,bool allowUnspecifiedOutput)1166 Result<Version> validateRequestArgumentsForModel(
1167 const std::vector<Request::Argument>& requestArguments,
1168 const std::vector<uint32_t>& operandIndexes, const std::vector<Operand>& operands,
1169 bool isOutput, bool allowUnspecifiedOutput) {
1170 auto version = Version::ANDROID_OC_MR1;
1171 // The request should specify as many arguments as were described in the model.
1172 const std::string_view type = isOutput ? "output" : "input";
1173 const size_t requestArgumentCount = requestArguments.size();
1174 NN_VALIDATE_EQ(requestArgumentCount, operandIndexes.size())
1175 << "Request specifies " << requestArgumentCount << " " << type << "s but the model has "
1176 << operandIndexes.size();
1177 for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
1178 requestArgumentIndex++) {
1179 const Request::Argument& requestArgument = requestArguments[requestArgumentIndex];
1180 // Get the operand index for this argument. We extract it from the list
1181 // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs.
1182 // We assume in this function that the model has been validated already.
1183 const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
1184 const Operand& operand = operands[operandIndex];
1185 if (requestArgument.lifetime != Request::Argument::LifeTime::NO_VALUE) {
1186 const bool isExtensionType = isExtension(operand.type);
1187 // If the argument specified a dimension, validate it.
1188 uint32_t modelRank = operand.dimensions.size();
1189 uint32_t requestRank = requestArgument.dimensions.size();
1190 if (requestRank == 0) {
1191 // NOTE: validateRequestArguments cannot validate unknown tensor rank with
1192 // extension operand type.
1193 if (!isExtensionType && !isNonExtensionScalar(operand.type)) {
1194 if (modelRank <= 0) {
1195 NN_VALIDATE(isOutput)
1196 << "Model has unknown input rank but the request does not "
1197 "specify the rank.";
1198 NN_VALIDATE(allowUnspecifiedOutput)
1199 << "Model has unknown output rank and request does not specify it.";
1200 // Unspecified output dimensions introduced in Android Q.
1201 version = combineVersions(version, Version::ANDROID_Q);
1202 }
1203 }
1204 // Validate that all the dimensions are specified in the model.
1205 for (size_t i = 0; i < modelRank; i++) {
1206 if (operand.dimensions[i] == 0) {
1207 NN_VALIDATE(isOutput && allowUnspecifiedOutput)
1208 << "Model has dimension " << i
1209 << " set to 0 but the request does not specify the dimension.";
1210 // Unspecified output dimensions introduced in Android Q.
1211 version = combineVersions(version, Version::ANDROID_Q);
1212 }
1213 }
1214 } else {
1215 NN_VALIDATE(modelRank == 0 || requestRank == modelRank)
1216 << "Request " << type << " " << requestArgumentIndex
1217 << " has number of dimensions (" << requestRank
1218 << ") different than the model's (" << modelRank << ")";
1219 for (size_t i = 0; i < requestRank; i++) {
1220 NN_VALIDATE(modelRank == 0 || operand.dimensions[i] == 0 ||
1221 requestArgument.dimensions[i] == operand.dimensions[i])
1222 << "Request " << type << " " << requestArgumentIndex
1223 << " has dimension " << i << " of " << requestArgument.dimensions[i]
1224 << " different than the model's " << operand.dimensions[i];
1225 if (requestArgument.dimensions[i] == 0) {
1226 NN_VALIDATE(isOutput && allowUnspecifiedOutput)
1227 << "Request " << type << " " << requestArgumentIndex
1228 << " has dimension " << i << " of zero";
1229 // Unspecified output dimensions introduced in Android Q.
1230 version = combineVersions(version, Version::ANDROID_Q);
1231 }
1232 }
1233 }
1234 // NOTE: validateRequestArguments cannot validate DataLocation::length
1235 // with extension operand type.
1236 if (!isExtensionType && requestArgument.location.length != 0) {
1237 const auto dimensions =
1238 NN_TRY(combineDimensions(operand.dimensions, requestArgument.dimensions));
1239 const size_t expectedLength = getNonExtensionSize(operand.type, dimensions).value();
1240 if (expectedLength != 0) {
1241 NN_VALIDATE_EQ(requestArgument.location.length, expectedLength)
1242 << "Request " << type << " " << requestArgumentIndex
1243 << " expected a size of " << expectedLength << " but got "
1244 << requestArgument.location.length;
1245 }
1246 }
1247 }
1248 }
1249 return version;
1250 }
1251
validateRequestForModelImpl(const Request & request,const Model & model,bool allowUnspecifiedOutput)1252 Result<Version> validateRequestForModelImpl(const Request& request, const Model& model,
1253 bool allowUnspecifiedOutput) {
1254 auto version = NN_TRY(validateRequest(request));
1255 version = combineVersions(version, NN_TRY(validateModel(model)));
1256 version = combineVersions(version,
1257 NN_TRY(validateRequestArgumentsForModel(
1258 request.inputs, model.main.inputIndexes, model.main.operands,
1259 /*isOutput=*/false, /*allowUnspecifiedOutput=*/true)));
1260 version = combineVersions(
1261 version, NN_TRY(validateRequestArgumentsForModel(
1262 request.outputs, model.main.outputIndexes, model.main.operands,
1263 /*isOutput=*/true, allowUnspecifiedOutput)));
1264 return version;
1265 }
1266
validateMemoryDescImpl(const BufferDesc & desc,const std::vector<SharedPreparedModel> & preparedModels,const std::vector<BufferRole> & inputRoles,const std::vector<BufferRole> & outputRoles,const std::function<const Model * (const SharedPreparedModel &)> & getModel,std::set<PreparedModelRole> * preparedModelRoles,Operand * combinedOperand)1267 Result<Version> validateMemoryDescImpl(
1268 const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
1269 const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
1270 const std::function<const Model*(const SharedPreparedModel&)>& getModel,
1271 std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
1272 NN_VALIDATE(!preparedModels.empty());
1273 NN_VALIDATE(!inputRoles.empty() || !outputRoles.empty());
1274
1275 std::set<PreparedModelRole> roles;
1276 std::vector<nn::Operand> operands;
1277 operands.reserve(inputRoles.size() + outputRoles.size());
1278 for (const auto& role : inputRoles) {
1279 NN_VALIDATE_LT(role.modelIndex, preparedModels.size());
1280 const auto& preparedModel = preparedModels[role.modelIndex];
1281 NN_VALIDATE(preparedModel != nullptr);
1282 const auto* model = getModel(preparedModel);
1283 NN_VALIDATE(model != nullptr);
1284 const auto& inputIndexes = model->main.inputIndexes;
1285 NN_VALIDATE_LT(role.ioIndex, inputIndexes.size());
1286 NN_VALIDATE_GT(role.probability, 0.0f);
1287 NN_VALIDATE_LE(role.probability, 1.0f);
1288 const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
1289 NN_VALIDATE(success);
1290 operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
1291 }
1292 for (const auto& role : outputRoles) {
1293 NN_VALIDATE_LT(role.modelIndex, preparedModels.size());
1294 const auto& preparedModel = preparedModels[role.modelIndex];
1295 NN_VALIDATE(preparedModel != nullptr);
1296 const auto* model = getModel(preparedModel);
1297 NN_VALIDATE(model != nullptr);
1298 const auto& outputIndexes = model->main.outputIndexes;
1299 NN_VALIDATE_LT(role.ioIndex, outputIndexes.size());
1300 NN_VALIDATE_GT(role.probability, 0.0f);
1301 NN_VALIDATE_LE(role.probability, 1.0f);
1302 const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
1303 NN_VALIDATE(success);
1304 operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
1305 }
1306
1307 CHECK(!operands.empty());
1308 const auto opType = operands.front().type;
1309
1310 Dimensions dimensions = desc.dimensions;
1311 for (const auto& operand : operands) {
1312 NN_VALIDATE_EQ(operand.type, opType) << operand.type << " vs " << operands.front().type;
1313 NN_VALIDATE_EQ(operand.scale, operands.front().scale);
1314 NN_VALIDATE_EQ(operand.zeroPoint, operands.front().zeroPoint);
1315 // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
1316 if (!isExtension(opType)) {
1317 NN_VALIDATE_EQ(operand.extraParams, operands.front().extraParams)
1318 << operand.extraParams << " vs " << operands.front().extraParams;
1319 }
1320 dimensions = NN_TRY(combineDimensions(dimensions, operand.dimensions));
1321 }
1322
1323 // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
1324 if (!isExtension(opType)) {
1325 NN_VALIDATE(!isNonExtensionScalar(opType) || dimensions.empty())
1326 << "invalid dimensions with scalar operand type.";
1327 }
1328
1329 if (preparedModelRoles != nullptr) {
1330 *preparedModelRoles = std::move(roles);
1331 }
1332 if (combinedOperand != nullptr) {
1333 *combinedOperand = operands.front();
1334 combinedOperand->dimensions = dimensions;
1335 }
1336 return Version::ANDROID_R;
1337 }
1338
// IOperationValidationContext implementation that exposes the inputs and outputs of a single
// operation, looked up in a shared operand list.
//
// The context stores references to the index vectors and to the operand vector that are passed to
// the constructor, and stores the operation name as a raw pointer; it does not copy any of them.
// All of them must therefore remain alive for as long as the context is used.
class OperationValidationContext : public IOperationValidationContext {
    DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);

  public:
    // `inputIndexes` and `outputIndexes` hold positions in `operands` of the operation's inputs
    // and outputs respectively.
    OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes,
                               const std::vector<uint32_t>& outputIndexes,
                               const std::vector<Operand>& operands)
        : operationName(operationName),
          inputIndexes(inputIndexes),
          outputIndexes(outputIndexes),
          operands(operands) {}

    const char* getOperationName() const override;

    // Accessors for the operation's inputs; `index` is a position in the input list.
    uint32_t getNumInputs() const override;
    OperandType getInputType(uint32_t index) const override;
    Shape getInputShape(uint32_t index) const override;
    const Operand::ExtraParams& getInputExtraParams(uint32_t index) const override;

    // Accessors for the operation's outputs; `index` is a position in the output list.
    uint32_t getNumOutputs() const override;
    OperandType getOutputType(uint32_t index) const override;
    Shape getOutputShape(uint32_t index) const override;

  private:
    // Resolve an input/output position to the operand it refers to.
    const Operand* getInputOperand(uint32_t index) const;
    const Operand* getOutputOperand(uint32_t index) const;

    const char* operationName;
    // Non-owning references to caller-provided data.
    const std::vector<uint32_t>& inputIndexes;
    const std::vector<uint32_t>& outputIndexes;
    const std::vector<Operand>& operands;
};
1371
getOperationName() const1372 const char* OperationValidationContext::getOperationName() const {
1373 return operationName;
1374 }
1375
getInputOperand(uint32_t index) const1376 const Operand* OperationValidationContext::getInputOperand(uint32_t index) const {
1377 return &operands.at(inputIndexes.at(index));
1378 }
1379
getOutputOperand(uint32_t index) const1380 const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const {
1381 return &operands.at(outputIndexes.at(index));
1382 }
1383
getNumInputs() const1384 uint32_t OperationValidationContext::getNumInputs() const {
1385 auto count = inputIndexes.size();
1386 CHECK_LE(count, std::numeric_limits<uint32_t>::max());
1387 return static_cast<uint32_t>(count);
1388 }
1389
getNumOutputs() const1390 uint32_t OperationValidationContext::getNumOutputs() const {
1391 auto count = outputIndexes.size();
1392 CHECK_LE(count, std::numeric_limits<uint32_t>::max());
1393 return static_cast<uint32_t>(count);
1394 }
1395
getInputType(uint32_t index) const1396 OperandType OperationValidationContext::getInputType(uint32_t index) const {
1397 return getInputOperand(index)->type;
1398 }
1399
getInputShape(uint32_t index) const1400 Shape OperationValidationContext::getInputShape(uint32_t index) const {
1401 const Operand* operand = getInputOperand(index);
1402 return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
1403 operand->extraParams};
1404 }
1405
getInputExtraParams(uint32_t index) const1406 const Operand::ExtraParams& OperationValidationContext::getInputExtraParams(uint32_t index) const {
1407 return getInputOperand(index)->extraParams;
1408 }
1409
getOutputType(uint32_t index) const1410 OperandType OperationValidationContext::getOutputType(uint32_t index) const {
1411 return getOutputOperand(index)->type;
1412 }
1413
getOutputShape(uint32_t index) const1414 Shape OperationValidationContext::getOutputShape(uint32_t index) const {
1415 const Operand* operand = getOutputOperand(index);
1416 return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
1417 operand->extraParams};
1418 }
1419
1420 // TODO(b/169345292): reduce the duplicate validation here
1421
validateOperandSymmPerChannelQuantParamsImpl(const Operand & operand,const Operand::SymmPerChannelQuantParams & channelQuant,const char * tag)1422 Result<void> validateOperandSymmPerChannelQuantParamsImpl(
1423 const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
1424 const char* tag) {
1425 if (operand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
1426 NN_VALIDATE_FAIL();
1427 }
1428
1429 NN_VALIDATE_LT(channelQuant.channelDim, operand.dimensions.size()) << tag;
1430 NN_VALIDATE(!channelQuant.scales.empty()) << tag;
1431 NN_VALIDATE_EQ(channelQuant.scales.size(), operand.dimensions[channelQuant.channelDim]) << tag;
1432 NN_VALIDATE_NE(operand.dimensions[channelQuant.channelDim], 0u)
1433 << tag << " channel dimension " << channelQuant.channelDim << " is underspecified";
1434 for (uint32_t i = 0; i < operand.dimensions[channelQuant.channelDim]; i++) {
1435 NN_VALIDATE_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]";
1436 }
1437 return {};
1438 }
1439
validateScalarDimensions(const Operand & type,const char * tag)1440 Result<void> validateScalarDimensions(const Operand& type, const char* tag) {
1441 NN_VALIDATE(type.dimensions.empty()) << tag << " invalid dimensions for scalar type";
1442 return {};
1443 }
1444
validateQuant8AsymmParams(const Operand & type,const char * tag)1445 Result<void> validateQuant8AsymmParams(const Operand& type, const char* tag) {
1446 NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 255)
1447 << tag << " invalid zeroPoint: " << type.zeroPoint;
1448 NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1449 return {};
1450 }
1451
validateQuant8AsymmSignedParams(const Operand & type,const char * tag)1452 Result<void> validateQuant8AsymmSignedParams(const Operand& type, const char* tag) {
1453 NN_VALIDATE(-128 <= type.zeroPoint && type.zeroPoint <= 127)
1454 << tag << " invalid zeroPoint: " << type.zeroPoint;
1455 NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1456 return {};
1457 }
1458
validateQuant8SymmParams(const Operand & type,const char * tag)1459 Result<void> validateQuant8SymmParams(const Operand& type, const char* tag) {
1460 NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint;
1461 NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1462 return {};
1463 }
1464
validateQuant16AsymmParams(const Operand & type,const char * tag)1465 Result<void> validateQuant16AsymmParams(const Operand& type, const char* tag) {
1466 NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 65535)
1467 << tag << " invalid zeroPoint: " << type.zeroPoint;
1468 NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1469 return {};
1470 }
1471
validateQuantSymmParams(const Operand & type,const char * tag)1472 Result<void> validateQuantSymmParams(const Operand& type, const char* tag) {
1473 NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
1474 NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale";
1475 return {};
1476 }
1477
validateNoQuantParams(const Operand & type,const char * tag)1478 Result<void> validateNoQuantParams(const Operand& type, const char* tag) {
1479 NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero";
1480 NN_VALIDATE_EQ(type.scale, 0.0f) << tag << " scale is not zero";
1481 return {};
1482 }
1483
validateTensorDimensions(const Operand & type,const Extension::OperandTypeInformation * extensionOperandTypeInfo,const char * tag,bool allowPartial)1484 Result<void> validateTensorDimensions(
1485 const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo,
1486 const char* tag, bool allowPartial) {
1487 if (!allowPartial) {
1488 NN_VALIDATE(!type.dimensions.empty()) << tag << " invalid operand dimensions";
1489 }
1490 uint64_t size = isExtension(type.type) ? extensionOperandTypeInfo->byteSize
1491 : getNonExtensionSize(type.type);
1492 constexpr uint64_t kMaxSize = std::numeric_limits<uint32_t>::max();
1493 for (size_t i = 0; i < type.dimensions.size(); i++) {
1494 if (!allowPartial) {
1495 NN_VALIDATE_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions";
1496 }
1497 if (type.dimensions[i] != 0) {
1498 size *= type.dimensions[i];
1499 NN_VALIDATE_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize;
1500 }
1501 }
1502 return {};
1503 }
1504
validateOperandTypeImpl(const Operand & type,const Extension::OperandTypeInformation * const extensionOperandTypeInfo,const char * tag,bool allowPartial)1505 Result<void> validateOperandTypeImpl(
1506 const Operand& type,
1507 const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag,
1508 bool allowPartial) {
1509 if (isExtension(type.type)) {
1510 NN_VALIDATE(extensionOperandTypeInfo != nullptr);
1511 if (extensionOperandTypeInfo->isTensor) {
1512 NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
1513 } else {
1514 NN_TRY(validateScalarDimensions(type, tag));
1515 }
1516 return validateNoQuantParams(type, tag);
1517 }
1518
1519 NN_VALIDATE(extensionOperandTypeInfo == nullptr);
1520 NN_TRY(validateOperandType(type.type));
1521
1522 if (isNonExtensionScalar(type.type)) {
1523 NN_TRY(validateScalarDimensions(type, tag));
1524 if (type.type != OperandType::OEM) { // Historically, we have allowed OEM types
1525 // to use quantization parameters.
1526 NN_TRY(validateNoQuantParams(type, tag));
1527 }
1528 } else {
1529 NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial));
1530 if (type.type == OperandType::TENSOR_QUANT8_ASYMM) {
1531 NN_TRY(validateQuant8AsymmParams(type, tag));
1532 } else if (type.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1533 NN_TRY(validateQuant8AsymmSignedParams(type, tag));
1534 } else if (type.type == OperandType::TENSOR_QUANT8_SYMM) {
1535 NN_TRY(validateQuant8SymmParams(type, tag));
1536 } else if (type.type == OperandType::TENSOR_QUANT16_ASYMM) {
1537 NN_TRY(validateQuant16AsymmParams(type, tag));
1538 } else if (type.type == OperandType::TENSOR_QUANT16_SYMM) {
1539 NN_TRY(validateQuantSymmParams(type, tag));
1540 } else if (type.type == OperandType::TENSOR_INT32 ||
1541 type.type == OperandType::TENSOR_OEM_BYTE) {
1542 // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters.
1543 // Historically, we have allowed OEM types to use quantization parameters.
1544 } else {
1545 NN_TRY(validateNoQuantParams(type, tag));
1546 }
1547 }
1548
1549 return {};
1550 }
1551
validateOperandListImpl(const std::vector<uint32_t> & list,size_t operandCount,const char * tag)1552 Result<void> validateOperandListImpl(const std::vector<uint32_t>& list, size_t operandCount,
1553 const char* tag) {
1554 for (size_t i = 0; i < list.size(); i++) {
1555 NN_VALIDATE_LT(list[i], operandCount) << tag << " invalid operand index at " << i << " = "
1556 << list[i] << ", operandCount " << operandCount;
1557 }
1558 return {};
1559 }
1560
validateOperationOperandTypes(const std::vector<Operand> & operands,const std::vector<uint32_t> & inputIndexes,const std::vector<OperandType> & inExpectedTypes,const std::vector<uint32_t> & outputIndexes,const std::vector<OperandType> & outExpectedInTypes)1561 Result<void> validateOperationOperandTypes(const std::vector<Operand>& operands,
1562 const std::vector<uint32_t>& inputIndexes,
1563 const std::vector<OperandType>& inExpectedTypes,
1564 const std::vector<uint32_t>& outputIndexes,
1565 const std::vector<OperandType>& outExpectedInTypes) {
1566 NN_VALIDATE_EQ(inputIndexes.size(), inExpectedTypes.size())
1567 << "Wrong operand count: expected " << inputIndexes.size() << " inputs, got "
1568 << inputIndexes.size() << " inputs";
1569 NN_VALIDATE_EQ(outputIndexes.size(), outExpectedInTypes.size())
1570 << "Wrong operand count: expected " << outputIndexes.size() << " outputs, got "
1571 << outputIndexes.size() << " outputs";
1572 for (size_t i = 0; i < inputIndexes.size(); i++) {
1573 NN_VALIDATE_EQ(operands[inputIndexes[i]].type, inExpectedTypes[i])
1574 << "Invalid input tensor type " << operands[inputIndexes[i]].type << " for input "
1575 << i << ", expected " << inExpectedTypes[i];
1576 }
1577 for (size_t i = 0; i < outputIndexes.size(); i++) {
1578 NN_VALIDATE_EQ(operands[outputIndexes[i]].type, outExpectedInTypes[i])
1579 << "Invalid output tensor type " << operands[outputIndexes[i]].type << " for input "
1580 << i << ", expected " << outExpectedInTypes[i];
1581 }
1582
1583 return {};
1584 }
1585
validateSubgraphReference(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand)1586 Result<void> validateSubgraphReference(const std::vector<Model::Subgraph>& subgraphs,
1587 const Operand& modelOperand) {
1588 NN_VALIDATE_EQ(modelOperand.type, OperandType::SUBGRAPH)
1589 << "Unexpected operand type: " << modelOperand.type;
1590 NN_VALIDATE_LT(modelOperand.location.offset, subgraphs.size()) << "Invalid subgraph reference";
1591 return {};
1592 }
getSubgraph(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand)1593 const Model::Subgraph& getSubgraph(const std::vector<Model::Subgraph>& subgraphs,
1594 const Operand& modelOperand) {
1595 return subgraphs.at(modelOperand.location.offset);
1596 }
getInputCount(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand)1597 uint32_t getInputCount(const std::vector<Model::Subgraph>& subgraphs, const Operand& modelOperand) {
1598 return getSubgraph(subgraphs, modelOperand).inputIndexes.size();
1599 }
getOutputCount(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand)1600 uint32_t getOutputCount(const std::vector<Model::Subgraph>& subgraphs,
1601 const Operand& modelOperand) {
1602 return getSubgraph(subgraphs, modelOperand).outputIndexes.size();
1603 }
getInputOperand(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand,uint32_t index)1604 const Operand& getInputOperand(const std::vector<Model::Subgraph>& subgraphs,
1605 const Operand& modelOperand, uint32_t index) {
1606 const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
1607 return subgraph.operands.at(subgraph.inputIndexes.at(index));
1608 }
getOutputOperand(const std::vector<Model::Subgraph> & subgraphs,const Operand & modelOperand,uint32_t index)1609 const Operand& getOutputOperand(const std::vector<Model::Subgraph>& subgraphs,
1610 const Operand& modelOperand, uint32_t index) {
1611 const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand);
1612 return subgraph.operands.at(subgraph.outputIndexes.at(index));
1613 }
1614
1615 // Checks if two operands have the same types, ranks (if specified), dimensions
1616 // (if specified), scales, zeroPoints, and extraParams.
compatible(const Operand & a,const Operand & b)1617 Result<void> compatible(const Operand& a, const Operand& b) {
1618 NN_VALIDATE_EQ(a.type, b.type) << a.type << " != " << b.type;
1619 if (!a.dimensions.empty() && !b.dimensions.empty()) {
1620 NN_VALIDATE_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions";
1621 for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) {
1622 if (a.dimensions[i] != 0 && b.dimensions[i] != 0) {
1623 NN_VALIDATE_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions";
1624 }
1625 }
1626 }
1627 NN_VALIDATE_EQ(a.scale, b.scale);
1628 NN_VALIDATE_EQ(a.zeroPoint, b.zeroPoint);
1629 NN_VALIDATE_EQ(a.extraParams, b.extraParams) << a.extraParams << " != " << b.extraParams;
1630 return {};
1631 }
1632
validateConditionOperand(const Operand & operand)1633 Result<void> validateConditionOperand(const Operand& operand) {
1634 NN_VALIDATE_EQ(operand.type, OperandType::TENSOR_BOOL8)
1635 << "Unexpected condition operand type: " << operand.type;
1636 NN_VALIDATE_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
1637 NN_VALIDATE_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
1638 return {};
1639 }
1640
validateIfOperation(const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs,const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs)1641 Result<Version> validateIfOperation(const std::vector<uint32_t>& inputs,
1642 const std::vector<uint32_t>& outputs,
1643 const std::vector<Operand>& operands,
1644 const std::vector<Model::Subgraph>& subgraphs) {
1645 namespace op = operation_if;
1646 NN_VALIDATE_GE(inputs.size(), 3u) << "IF must have at least 3 inputs";
1647 NN_VALIDATE_GE(outputs.size(), 1u) << "IF must have at least 1 output";
1648 auto validateBranchOperand = [&](const Operand& branchModelOperand) -> Result<void> {
1649 auto result = validateSubgraphReference(subgraphs, branchModelOperand);
1650 if (!result.has_value()) {
1651 return error() << std::move(result).error()
1652 << " -- Operand is not a valid subgraph reference";
1653 }
1654 const uint32_t branchModelInputCount = getInputCount(subgraphs, branchModelOperand);
1655 const uint32_t branchModelOutputCount = getOutputCount(subgraphs, branchModelOperand);
1656 NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + branchModelInputCount);
1657 NN_VALIDATE_EQ(outputs.size(), branchModelOutputCount);
1658 for (uint32_t i = 0; i < branchModelInputCount; ++i) {
1659 const Operand& innerOperand = getInputOperand(subgraphs, branchModelOperand, i);
1660 const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1661 NN_TRY(compatible(innerOperand, outerOperand));
1662 }
1663 for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
1664 const Operand& innerOperand = getOutputOperand(subgraphs, branchModelOperand, i);
1665 const Operand& outerOperand = operands[outputs[i]];
1666 NN_TRY(compatible(innerOperand, outerOperand));
1667 }
1668 return {};
1669 };
1670 auto result = validateConditionOperand(operands[inputs[op::kCondBoolOperand]]);
1671 if (!result.has_value()) {
1672 return error() << std::move(result).error() << " for IF condition operand";
1673 }
1674 result = validateBranchOperand(operands[inputs[op::kThenModelOperand]]);
1675 if (!result.has_value()) {
1676 return error() << std::move(result).error() << " for IF then model";
1677 }
1678 result = validateBranchOperand(operands[inputs[op::kElseModelOperand]]);
1679 if (!result.has_value()) {
1680 return error() << std::move(result).error() << " for IF else model";
1681 }
1682 return Version::ANDROID_R;
1683 }
1684
validateControlFlowOperandUnknownSize(const Operand & operand)1685 Result<Version> validateControlFlowOperandUnknownSize(const Operand& operand) {
1686 if (!isExtension(operand.type) && getNonExtensionSize(operand).value() == 0) {
1687 // 1.3 HAL (corresponding to Version::ANDROID_R) does not support CF operations with
1688 // operands of unknown size. See http://b/132458982#comment63.
1689 return Version::CURRENT_RUNTIME;
1690 }
1691 return Version::ANDROID_R;
1692 }
1693
validateWhileOperation(const std::vector<uint32_t> & inputs,const std::vector<uint32_t> & outputs,const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs)1694 Result<Version> validateWhileOperation(const std::vector<uint32_t>& inputs,
1695 const std::vector<uint32_t>& outputs,
1696 const std::vector<Operand>& operands,
1697 const std::vector<Model::Subgraph>& subgraphs) {
1698 // Let the loop have
1699 // - m >= 1 input-output operands,
1700 // - k >= 0 state-only operands, and
1701 // - n >= 0 input-only operands.
1702 // Then
1703 // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
1704 // - the condition model has (m + k + n) inputs and 1 output.
1705 // - the body model has (m + k + n) inputs and (m + k) outputs.
1706 namespace op = operation_while;
1707 NN_VALIDATE_GE(inputs.size(), 3u) << "WHILE must have at least 3 inputs";
1708 NN_VALIDATE_GE(outputs.size(), 1u) << "WHILE must have at least 1 output";
1709 auto validateCondOperand = [&](const Operand& condModelOperand) -> Result<Version> {
1710 Version version = Version::ANDROID_R;
1711 auto result = validateSubgraphReference(subgraphs, condModelOperand);
1712 if (!result.has_value()) {
1713 return error() << std::move(result).error()
1714 << " -- Operand is not a valid subgraph reference";
1715 }
1716 const uint32_t condModelInputCount = getInputCount(subgraphs, condModelOperand);
1717 const uint32_t condModelOutputCount = getOutputCount(subgraphs, condModelOperand);
1718 NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + condModelInputCount);
1719 NN_VALIDATE_EQ(condModelOutputCount, 1u);
1720 for (uint32_t i = 0; i < condModelInputCount; ++i) {
1721 const Operand& innerOperand = getInputOperand(subgraphs, condModelOperand, i);
1722 const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1723 NN_TRY(compatible(innerOperand, outerOperand));
1724 version = combineVersions(version,
1725 NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
1726 version = combineVersions(version,
1727 NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1728 }
1729 NN_TRY(validateConditionOperand(getOutputOperand(subgraphs, condModelOperand, 0)));
1730 return version;
1731 };
1732 auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> Result<Version> {
1733 Version version = Version::ANDROID_R;
1734 auto result = validateSubgraphReference(subgraphs, bodyModelOperand);
1735 if (!result.has_value()) {
1736 return error() << std::move(result).error()
1737 << " -- Operand is not a valid subgraph reference";
1738 }
1739 const uint32_t bodyModelInputCount = getInputCount(subgraphs, bodyModelOperand);
1740 const uint32_t bodyModelOutputCount = getOutputCount(subgraphs, bodyModelOperand);
1741 NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + bodyModelInputCount);
1742 NN_VALIDATE_GE(bodyModelOutputCount, outputs.size());
1743 NN_VALIDATE_GE(bodyModelInputCount, bodyModelOutputCount);
1744 const uint32_t inputOutputCount = outputs.size();
1745 const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
1746 const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
1747 for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
1748 const Operand& innerOperand = getInputOperand(subgraphs, bodyModelOperand, i);
1749 const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
1750 NN_TRY(compatible(innerOperand, outerOperand));
1751 version = combineVersions(version,
1752 NN_TRY(validateControlFlowOperandUnknownSize(innerOperand)));
1753 version = combineVersions(version,
1754 NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1755 }
1756 for (uint32_t i = 0; i < inputOutputCount; ++i) {
1757 const Operand& innerOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
1758 const Operand& outerOperand = operands[outputs[i]];
1759 NN_TRY(compatible(innerOperand, outerOperand));
1760 version = combineVersions(version,
1761 NN_TRY(validateControlFlowOperandUnknownSize(outerOperand)));
1762 }
1763 for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
1764 const Operand& inputOperand = getInputOperand(subgraphs, bodyModelOperand, i);
1765 const Operand& outputOperand = getOutputOperand(subgraphs, bodyModelOperand, i);
1766 NN_TRY(compatible(inputOperand, outputOperand));
1767 version = combineVersions(version,
1768 NN_TRY(validateControlFlowOperandUnknownSize(outputOperand)));
1769 }
1770 return version;
1771 };
1772 auto result = validateCondOperand(operands[inputs[op::kCondModelOperand]]);
1773 if (!result.has_value()) {
1774 return error() << std::move(result).error() << " for WHILE condition model";
1775 }
1776 auto version = result.value();
1777 result = validateBodyOperand(operands[inputs[op::kBodyModelOperand]]);
1778 if (!result.has_value()) {
1779 return error() << std::move(result).error() << " for WHILE body model";
1780 }
1781 version = combineVersions(version, result.value());
1782 return version;
1783 }
1784
validateOperationButNotOperandsImpl(const Operation & operation,const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs)1785 Result<Version> validateOperationButNotOperandsImpl(const Operation& operation,
1786 const std::vector<Operand>& operands,
1787 const std::vector<Model::Subgraph>& subgraphs) {
1788 const auto opType = operation.type;
1789 const auto& inputIndexes = operation.inputs;
1790 const auto& outputIndexes = operation.outputs;
1791
1792 NN_TRY(validateOperandListImpl(inputIndexes, operands.size(),
1793 "ANeuralNetworksModel_addOperation inputs"));
1794 NN_TRY(validateOperandListImpl(outputIndexes, operands.size(),
1795 "ANeuralNetworksModel_addOperation outputs"));
1796
1797 if (isExtension(opType)) {
1798 // There is no other validation we can do for an extension operation.
1799 return Version::ANDROID_Q;
1800 }
1801
1802 auto invalidInOutNumberMessage = [opType, &inputIndexes, &outputIndexes](int expIn,
1803 int expOut) {
1804 std::ostringstream os;
1805 os << "Invalid number of input operands (" << inputIndexes.size() << ", expected " << expIn
1806 << ") or output operands (" << outputIndexes.size() << ", expected " << expOut
1807 << ") for operation " << opType;
1808 return os.str();
1809 };
1810
1811 switch (opType) {
1812 case OperationType::OEM_OPERATION: {
1813 return Version::ANDROID_OC_MR1;
1814 }
1815 case OperationType::RESHAPE: {
1816 NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
1817 << invalidInOutNumberMessage(2, 1);
1818 auto inputType = operands[inputIndexes[0]].type;
1819 Version version;
1820 std::vector<OperandType> inExpectedTypes;
1821 std::vector<OperandType> outExpectedTypes;
1822 if (inputType == OperandType::TENSOR_FLOAT32) {
1823 version = Version::ANDROID_OC_MR1;
1824 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32};
1825 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1826 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1827 version = Version::ANDROID_Q;
1828 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32};
1829 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1830 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1831 version = Version::ANDROID_OC_MR1;
1832 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32};
1833 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1834 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1835 version = Version::ANDROID_R;
1836 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
1837 OperandType::TENSOR_INT32};
1838 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1839 } else {
1840 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
1841 }
1842 const auto inputRank = operands[inputIndexes[0]].dimensions.size();
1843 NN_VALIDATE_LE(inputRank, 4u)
1844 << "Unsupported input tensor rank for operation " << opType;
1845 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1846 outputIndexes, outExpectedTypes));
1847 return version;
1848 }
1849 case OperationType::DEPTH_TO_SPACE: {
1850 NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
1851 outputIndexes.size() == 1)
1852 << "Invalid number of input operands (" << inputIndexes.size()
1853 << ", expected 3 or 2) or output operands (" << outputIndexes.size()
1854 << ", expected 1) for operation " << opType;
1855 auto inputType = operands[inputIndexes[0]].type;
1856 Version version;
1857 std::vector<OperandType> inExpectedTypes;
1858 std::vector<OperandType> outExpectedTypes;
1859 if (inputType == OperandType::TENSOR_FLOAT32) {
1860 version = Version::ANDROID_OC_MR1;
1861 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
1862 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1863 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1864 version = Version::ANDROID_Q;
1865 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
1866 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1867 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1868 version = Version::ANDROID_OC_MR1;
1869 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
1870 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1871 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1872 version = Version::ANDROID_R;
1873 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
1874 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1875 } else {
1876 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
1877 }
1878 if (inputIndexes.size() == 3) {
1879 inExpectedTypes.push_back(OperandType::BOOL);
1880 version = combineVersions(version, Version::ANDROID_Q);
1881 } else {
1882 version = combineVersions(version, Version::ANDROID_OC_MR1);
1883 }
1884 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1885 outputIndexes, outExpectedTypes));
1886 return version;
1887 }
1888 case OperationType::SPACE_TO_DEPTH: {
1889 NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
1890 outputIndexes.size() == 1)
1891 << "Invalid number of input operands (" << inputIndexes.size()
1892 << ", expected 3 or 2) or output operands (" << outputIndexes.size()
1893 << ", expected 1) for operation " << opType;
1894 auto inputType = operands[inputIndexes[0]].type;
1895 Version version;
1896 std::vector<OperandType> inExpectedTypes;
1897 std::vector<OperandType> outExpectedTypes;
1898 if (inputType == OperandType::TENSOR_FLOAT32) {
1899 version = Version::ANDROID_OC_MR1;
1900 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32};
1901 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
1902 } else if (inputType == OperandType::TENSOR_FLOAT16) {
1903 version = Version::ANDROID_Q;
1904 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32};
1905 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
1906 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1907 version = Version::ANDROID_OC_MR1;
1908 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32};
1909 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
1910 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1911 version = Version::ANDROID_R;
1912 inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32};
1913 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
1914 } else {
1915 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
1916 }
1917 if (inputIndexes.size() == 3) {
1918 inExpectedTypes.push_back(OperandType::BOOL);
1919 version = combineVersions(version, Version::ANDROID_Q);
1920 } else {
1921 version = combineVersions(version, Version::ANDROID_OC_MR1);
1922 }
1923 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1924 outputIndexes, outExpectedTypes));
1925 return version;
1926 }
1927 case OperationType::EMBEDDING_LOOKUP: {
1928 NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
1929 << invalidInOutNumberMessage(2, 1);
1930 auto inputType = operands[inputIndexes[1]].type;
1931 NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
1932 inputType == OperandType::TENSOR_FLOAT32 ||
1933 inputType == OperandType::TENSOR_INT32 ||
1934 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
1935 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
1936 << "Unsupported input tensor type for operation " << opType;
1937 Version version;
1938 std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32, inputType};
1939 std::vector<OperandType> outExpectedTypes = {inputType};
1940 if (inputType == OperandType::TENSOR_FLOAT16 ||
1941 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
1942 version = Version::ANDROID_R;
1943 } else if (inputType == OperandType::TENSOR_INT32 ||
1944 inputType == OperandType::TENSOR_QUANT8_ASYMM) {
1945 version = Version::ANDROID_Q;
1946 } else {
1947 version = Version::ANDROID_OC_MR1;
1948 }
1949 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1950 outputIndexes, outExpectedTypes));
1951 return version;
1952 }
1953 case OperationType::HASHTABLE_LOOKUP: {
1954 NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 2)
1955 << invalidInOutNumberMessage(3, 2);
1956 auto inputType = operands[inputIndexes[2]].type;
1957 NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
1958 inputType == OperandType::TENSOR_INT32 ||
1959 inputType == OperandType::TENSOR_QUANT8_ASYMM)
1960 << "Unsupported input tensor type for operation " << opType;
1961 std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_INT32,
1962 OperandType::TENSOR_INT32, inputType};
1963 std::vector<OperandType> outExpectedTypes = {inputType,
1964 OperandType::TENSOR_QUANT8_ASYMM};
1965 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
1966 outputIndexes, outExpectedTypes));
1967 return Version::ANDROID_OC_MR1;
1968 }
1969 case OperationType::LSH_PROJECTION: {
1970 NN_VALIDATE(inputIndexes.size() == 4 && outputIndexes.size() == 1)
1971 << invalidInOutNumberMessage(4, 1);
1972 auto inputType = operands[inputIndexes[1]].type;
1973 NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
1974 inputType == OperandType::TENSOR_FLOAT32 ||
1975 inputType == OperandType::TENSOR_INT32 ||
1976 inputType == OperandType::TENSOR_QUANT8_ASYMM)
1977 << "Unsupported input tensor type for operation " << opType;
1978 auto hashType = operands[inputIndexes[0]].type;
1979 Version version;
1980 std::vector<OperandType> inExpectedTypes;
1981 if (hashType == OperandType::TENSOR_FLOAT16) {
1982 version = Version::ANDROID_Q;
1983 inExpectedTypes = {
1984 OperandType::TENSOR_FLOAT16,
1985 inputType,
1986 OperandType::TENSOR_FLOAT16,
1987 OperandType::INT32,
1988 };
1989 } else if (hashType == OperandType::TENSOR_FLOAT32) {
1990 version = Version::ANDROID_OC_MR1;
1991 inExpectedTypes = {
1992 OperandType::TENSOR_FLOAT32,
1993 inputType,
1994 OperandType::TENSOR_FLOAT32,
1995 OperandType::INT32,
1996 };
1997 } else {
1998 NN_VALIDATE_FAIL() << "Unsupported hash tensor type for operation " << opType;
1999 }
2000 std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
2001 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2002 outputIndexes, outExpectedTypes));
2003 return version;
2004 }
2005 case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM: {
2006 const uint32_t kNumOutputs = 2;
2007 const uint32_t kNumOutputsMerged = 1;
2008 const uint32_t kNumOutputsWithState = 6;
2009 const uint32_t kNumOutputsMergedWithState = 5;
2010 NN_VALIDATE(inputIndexes.size() == 61 &&
2011 (outputIndexes.size() == kNumOutputs ||
2012 outputIndexes.size() == kNumOutputsMerged ||
2013 outputIndexes.size() == kNumOutputsWithState ||
2014 outputIndexes.size() == kNumOutputsMergedWithState))
2015 << "Invalid number of input operands (" << inputIndexes.size()
2016 << ", expected 61) or output operands (" << outputIndexes.size()
2017 << ", expected 1, 2, 5 or 6) for operation " << opType;
2018
2019 std::vector<OperandType> inExpectedTypes;
2020 auto inputType = operands[inputIndexes[0]].type;
2021 NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
2022 inputType == OperandType::TENSOR_FLOAT16)
2023 << "Unsupported input tensor type for operation " << opType;
2024
2025 inExpectedTypes = {};
2026 for (int i = 0; i < 48; ++i) {
2027 inExpectedTypes.push_back(inputType);
2028 }
2029 inExpectedTypes.push_back(OperandType::INT32);
2030 inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
2031 ? OperandType::FLOAT32
2032 : OperandType::FLOAT16);
2033 inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32
2034 ? OperandType::FLOAT32
2035 : OperandType::FLOAT16);
2036 inExpectedTypes.push_back(OperandType::BOOL);
2037 inExpectedTypes.push_back(OperandType::BOOL);
2038 for (int i = 0; i < 8; ++i) {
2039 inExpectedTypes.push_back(inputType);
2040 }
2041
2042 Version version = Version::ANDROID_Q;
2043 if (outputIndexes.size() == kNumOutputsWithState ||
2044 outputIndexes.size() == kNumOutputsMergedWithState) {
2045 version = Version::ANDROID_R;
2046 }
2047 std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType);
2048 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2049 outputIndexes, outExpectedTypes));
2050 return version;
2051 }
2052 case OperationType::LSTM: {
2053 NN_VALIDATE((inputIndexes.size() == 23 || inputIndexes.size() == 27) &&
2054 outputIndexes.size() == 4)
2055 << "Invalid number of input operands (" << inputIndexes.size()
2056 << ", expected 23 or 27) or output operands (" << outputIndexes.size()
2057 << ", expected 4) for operation " << opType;
2058 std::vector<OperandType> inExpectedTypes;
2059 std::vector<OperandType> outExpectedTypes;
2060 auto inputType = operands[inputIndexes[0]].type;
2061 NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 ||
2062 inputType == OperandType::TENSOR_FLOAT16)
2063 << "Unsupported input tensor type for operation " << opType;
2064
2065 Version version = Version::ANDROID_OC_MR1;
2066 inExpectedTypes = {inputType, inputType, inputType, inputType, inputType,
2067 inputType, inputType, inputType, inputType, inputType,
2068 inputType, inputType, inputType, inputType, inputType,
2069 inputType, inputType, inputType, inputType, inputType,
2070 OperandType::INT32};
2071 if (inputType == OperandType::TENSOR_FLOAT32) {
2072 inExpectedTypes.push_back(OperandType::FLOAT32);
2073 inExpectedTypes.push_back(OperandType::FLOAT32);
2074 } else {
2075 version = Version::ANDROID_Q;
2076 inExpectedTypes.push_back(OperandType::FLOAT16);
2077 inExpectedTypes.push_back(OperandType::FLOAT16);
2078 }
2079
2080 outExpectedTypes = {inputType, inputType, inputType, inputType};
2081 if (inputIndexes.size() == 23) {
2082 version = combineVersions(version, Version::ANDROID_OC_MR1);
2083 } else {
2084 version = combineVersions(version, Version::ANDROID_Q);
2085 for (int i = 0; i < 4; ++i) {
2086 inExpectedTypes.push_back(inputType);
2087 }
2088 }
2089 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2090 outputIndexes, outExpectedTypes));
2091 return version;
2092 }
2093 case OperationType::QUANTIZED_16BIT_LSTM: {
2094 NN_VALIDATE(inputIndexes.size() == 15 && outputIndexes.size() == 2)
2095 << invalidInOutNumberMessage(15, 2);
2096 std::vector<OperandType> inExpectedTypes = {
2097 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
2098 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
2099 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
2100 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM,
2101 OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32,
2102 OperandType::TENSOR_INT32, OperandType::TENSOR_INT32,
2103 OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_SYMM,
2104 OperandType::TENSOR_QUANT8_ASYMM};
2105 std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_QUANT16_SYMM,
2106 OperandType::TENSOR_QUANT8_ASYMM};
2107 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2108 outputIndexes, outExpectedTypes));
2109 return Version::ANDROID_Q;
2110 }
2111 case OperationType::RANDOM_MULTINOMIAL: {
2112 NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
2113 << invalidInOutNumberMessage(3, 1);
2114 OperandType inputType = operands[inputIndexes[0]].type;
2115 std::vector<OperandType> inExpectedTypes;
2116 if (inputType == OperandType::TENSOR_FLOAT32 ||
2117 inputType == OperandType::TENSOR_FLOAT16) {
2118 inExpectedTypes = {inputType, OperandType::INT32, OperandType::TENSOR_INT32};
2119 } else {
2120 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2121 }
2122 std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
2123 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2124 outputIndexes, outExpectedTypes));
2125 return Version::ANDROID_Q;
2126 }
2127 case OperationType::RNN: {
2128 NN_VALIDATE(inputIndexes.size() == 6 && outputIndexes.size() == 2)
2129 << invalidInOutNumberMessage(6, 2);
2130 OperandType inputType = operands[inputIndexes[0]].type;
2131 Version version;
2132 std::vector<OperandType> inExpectedTypes;
2133 std::vector<OperandType> outExpectedTypes;
2134 if (inputType == OperandType::TENSOR_FLOAT32) {
2135 version = Version::ANDROID_OC_MR1;
2136 inExpectedTypes = {
2137 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
2138 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
2139 OperandType::TENSOR_FLOAT32, OperandType::INT32,
2140 };
2141 outExpectedTypes = {
2142 OperandType::TENSOR_FLOAT32,
2143 OperandType::TENSOR_FLOAT32,
2144 };
2145 } else if (inputType == OperandType::TENSOR_FLOAT16) {
2146 version = Version::ANDROID_Q;
2147 inExpectedTypes = {
2148 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
2149 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
2150 OperandType::TENSOR_FLOAT16, OperandType::INT32,
2151 };
2152 outExpectedTypes = {
2153 OperandType::TENSOR_FLOAT16,
2154 OperandType::TENSOR_FLOAT16,
2155 };
2156 } else {
2157 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2158 }
2159 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2160 outputIndexes, outExpectedTypes));
2161 return version;
2162 }
2163 case OperationType::SVDF: {
2164 NN_VALIDATE(inputIndexes.size() == 7 && outputIndexes.size() == 2)
2165 << invalidInOutNumberMessage(7, 2);
2166 Version version;
2167 OperandType inputType = operands[inputIndexes[0]].type;
2168 if (inputType == OperandType::TENSOR_FLOAT32) {
2169 version = Version::ANDROID_OC_MR1;
2170 } else if (inputType == OperandType::TENSOR_FLOAT16) {
2171 version = Version::ANDROID_Q;
2172 } else {
2173 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2174 }
2175 std::vector<OperandType> inExpectedTypes = {
2176 inputType, inputType, inputType, inputType,
2177 inputType, OperandType::INT32, OperandType::INT32,
2178 };
2179 std::vector<OperandType> outExpectedTypes = {inputType, inputType};
2180 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2181 outputIndexes, outExpectedTypes));
2182 return version;
2183 }
2184 case OperationType::BATCH_TO_SPACE_ND: {
2185 NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) &&
2186 outputIndexes.size() == 1)
2187 << "Invalid number of input operands (" << inputIndexes.size()
2188 << ", expected 3 or 2) or output operands (" << outputIndexes.size()
2189 << ", expected 1) for operation " << opType;
2190 auto inputType = operands[inputIndexes[0]].type;
2191 Version version = Version::ANDROID_OC_MR1;
2192 std::vector<OperandType> inExpectedTypes;
2193 std::vector<OperandType> outExpectedTypes;
2194 if (inputType == OperandType::TENSOR_FLOAT32) {
2195 inExpectedTypes = {
2196 OperandType::TENSOR_FLOAT32,
2197 OperandType::TENSOR_INT32,
2198 };
2199 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2200 } else if (inputType == OperandType::TENSOR_FLOAT16) {
2201 version = Version::ANDROID_Q;
2202 inExpectedTypes = {
2203 OperandType::TENSOR_FLOAT16,
2204 OperandType::TENSOR_INT32,
2205 };
2206 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2207 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
2208 inExpectedTypes = {
2209 OperandType::TENSOR_QUANT8_ASYMM,
2210 OperandType::TENSOR_INT32,
2211 };
2212 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
2213 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2214 version = Version::ANDROID_R;
2215 inExpectedTypes = {
2216 OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
2217 OperandType::TENSOR_INT32,
2218 };
2219 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
2220 } else {
2221 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2222 }
2223 if (inputIndexes.size() == 3) {
2224 inExpectedTypes.push_back(OperandType::BOOL);
2225 version = combineVersions(version, Version::ANDROID_Q);
2226 } else {
2227 version = combineVersions(version, Version::ANDROID_P);
2228 }
2229 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2230 outputIndexes, outExpectedTypes));
2231 return version;
2232 }
2233 case OperationType::SPACE_TO_BATCH_ND: {
2234 NN_VALIDATE((inputIndexes.size() == 4 || inputIndexes.size() == 3) &&
2235 outputIndexes.size() == 1)
2236 << "Invalid number of input operands (" << inputIndexes.size()
2237 << ", expected 4 or 3) or output operands (" << outputIndexes.size()
2238 << ", expected 1) for operation " << opType;
2239 auto inputType = operands[inputIndexes[0]].type;
2240 Version version = Version::ANDROID_OC_MR1;
2241 std::vector<OperandType> inExpectedTypes;
2242 std::vector<OperandType> outExpectedTypes;
2243 if (inputType == OperandType::TENSOR_FLOAT32) {
2244 inExpectedTypes = {
2245 OperandType::TENSOR_FLOAT32,
2246 OperandType::TENSOR_INT32,
2247 OperandType::TENSOR_INT32,
2248 };
2249 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2250 } else if (inputType == OperandType::TENSOR_FLOAT16) {
2251 version = Version::ANDROID_Q;
2252 inExpectedTypes = {
2253 OperandType::TENSOR_FLOAT16,
2254 OperandType::TENSOR_INT32,
2255 OperandType::TENSOR_INT32,
2256 };
2257 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2258 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
2259 if (operands[inputIndexes[0]].zeroPoint != 0) {
2260 version = Version::ANDROID_Q;
2261 }
2262 inExpectedTypes = {
2263 OperandType::TENSOR_QUANT8_ASYMM,
2264 OperandType::TENSOR_INT32,
2265 OperandType::TENSOR_INT32,
2266 };
2267 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
2268 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2269 version = Version::ANDROID_R;
2270 inExpectedTypes = {
2271 OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
2272 OperandType::TENSOR_INT32,
2273 OperandType::TENSOR_INT32,
2274 };
2275 outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED};
2276 } else {
2277 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2278 }
2279 if (inputIndexes.size() == 4) {
2280 inExpectedTypes.push_back(OperandType::BOOL);
2281 version = combineVersions(version, Version::ANDROID_Q);
2282 } else {
2283 version = combineVersions(version, Version::ANDROID_P);
2284 }
2285 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2286 outputIndexes, outExpectedTypes));
2287 return version;
2288 }
2289 case OperationType::PAD: {
2290 NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2291 << invalidInOutNumberMessage(2, 1);
2292 auto inputType = operands[inputIndexes[0]].type;
2293 Version version;
2294 std::vector<OperandType> inExpectedTypes;
2295 std::vector<OperandType> outExpectedTypes;
2296 if (inputType == OperandType::TENSOR_FLOAT32) {
2297 version = Version::ANDROID_P;
2298 inExpectedTypes = {
2299 OperandType::TENSOR_FLOAT32,
2300 OperandType::TENSOR_INT32,
2301 };
2302 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2303 } else if (inputType == OperandType::TENSOR_FLOAT16) {
2304 version = Version::ANDROID_Q;
2305 inExpectedTypes = {
2306 OperandType::TENSOR_FLOAT16,
2307 OperandType::TENSOR_INT32,
2308 };
2309 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2310 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2311 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2312 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2313 version = Version::ANDROID_R;
2314 } else {
2315 if (operands[inputIndexes[0]].zeroPoint == 0) {
2316 version = Version::ANDROID_P;
2317 } else {
2318 version = Version::ANDROID_Q;
2319 }
2320 }
2321 inExpectedTypes = {
2322 inputType,
2323 OperandType::TENSOR_INT32,
2324 };
2325 outExpectedTypes = {inputType};
2326 } else {
2327 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2328 }
2329 const auto inputRank = operands[inputIndexes[0]].dimensions.size();
2330 NN_VALIDATE_LE(inputRank, 4u)
2331 << "Unsupported input tensor rank for operation " << opType;
2332 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2333 outputIndexes, outExpectedTypes));
2334 return version;
2335 }
2336 case OperationType::PAD_V2: {
2337 NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
2338 << invalidInOutNumberMessage(3, 1);
2339 auto inputType = operands[inputIndexes[0]].type;
2340 Version version;
2341 std::vector<OperandType> inExpectedTypes;
2342 std::vector<OperandType> outExpectedTypes;
2343 if (inputType == OperandType::TENSOR_FLOAT32) {
2344 version = Version::ANDROID_Q;
2345 inExpectedTypes = {
2346 OperandType::TENSOR_FLOAT32,
2347 OperandType::TENSOR_INT32,
2348 OperandType::FLOAT32,
2349 };
2350 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2351 } else if (inputType == OperandType::TENSOR_FLOAT16) {
2352 version = Version::ANDROID_Q;
2353 inExpectedTypes = {
2354 OperandType::TENSOR_FLOAT16,
2355 OperandType::TENSOR_INT32,
2356 OperandType::FLOAT16,
2357 };
2358 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2359 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2360 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2361 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2362 version = Version::ANDROID_R;
2363 } else {
2364 version = Version::ANDROID_Q;
2365 }
2366 inExpectedTypes = {
2367 inputType,
2368 OperandType::TENSOR_INT32,
2369 OperandType::INT32,
2370 }; // TODO(b/116699425): Make it UINT8.
2371 outExpectedTypes = {inputType};
2372 } else {
2373 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2374 }
2375 const auto inputRank = operands[inputIndexes[0]].dimensions.size();
2376 NN_VALIDATE_LE(inputRank, 4u)
2377 << "Unsupported input tensor rank for operation " << opType;
2378 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2379 outputIndexes, outExpectedTypes));
2380 return version;
2381 }
2382 case OperationType::CAST: {
2383 NN_VALIDATE(inputIndexes.size() == 1 && outputIndexes.size() == 1)
2384 << invalidInOutNumberMessage(1, 1);
2385 auto inputOperand = operands[inputIndexes[0]];
2386 auto outputOperand = operands[outputIndexes[0]];
2387 auto inputType = inputOperand.type;
2388 auto outputType = outputOperand.type;
2389 Version version;
2390 std::vector<OperandType> inExpectedTypes;
2391 std::vector<OperandType> outExpectedTypes;
2392 if ((inputType == OperandType::TENSOR_FLOAT16 ||
2393 inputType == OperandType::TENSOR_FLOAT32 ||
2394 inputType == OperandType::TENSOR_INT32 ||
2395 inputType == OperandType::TENSOR_QUANT8_ASYMM) &&
2396 (outputType == OperandType::TENSOR_FLOAT16 ||
2397 outputType == OperandType::TENSOR_FLOAT32 ||
2398 outputType == OperandType::TENSOR_INT32 ||
2399 outputType == OperandType::TENSOR_QUANT8_ASYMM)) {
2400 version = Version::ANDROID_Q;
2401 inExpectedTypes = {inputType};
2402 outExpectedTypes = {outputType};
2403 } else if (inputType == OperandType::TENSOR_BOOL8 ||
2404 inputType == OperandType::TENSOR_QUANT16_ASYMM ||
2405 inputType == OperandType::TENSOR_QUANT16_SYMM ||
2406 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
2407 inputType == OperandType::TENSOR_QUANT8_SYMM) {
2408 version = Version::ANDROID_R;
2409 inExpectedTypes = {inputType};
2410 outExpectedTypes = {inputType}; // Only identity CAST is supported.
2411 } else {
2412 NN_VALIDATE_FAIL() << "Unsupported data type for operation " << opType;
2413 }
2414 // Validate that output shape is equal to input shape if dimensions
2415 // are already known.
2416 auto getNumberOfElements = [](const std::vector<uint32_t>& dims) {
2417 if (dims.empty()) {
2418 return 0;
2419 }
2420 return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
2421 };
2422 NN_VALIDATE(inputOperand.dimensions.empty() || outputOperand.dimensions.empty() ||
2423 getNumberOfElements(outputOperand.dimensions) == 0 ||
2424 inputOperand.dimensions == outputOperand.dimensions);
2425 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2426 outputIndexes, outExpectedTypes));
2427 return version;
2428 }
2429 case OperationType::MEAN: {
2430 NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1)
2431 << invalidInOutNumberMessage(3, 1);
2432 const auto inputRank = operands[inputIndexes[0]].dimensions.size();
2433 NN_VALIDATE_LE(inputRank, 4u)
2434 << "Unsupported input tensor rank for operation " << opType;
2435 auto inputType = operands[inputIndexes[0]].type;
2436 Version version;
2437 if (inputType == OperandType::TENSOR_FLOAT32 ||
2438 inputType == OperandType::TENSOR_QUANT8_ASYMM) {
2439 version = Version::ANDROID_P;
2440 } else if (inputType == OperandType::TENSOR_FLOAT16) {
2441 version = Version::ANDROID_Q;
2442 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2443 version = Version::ANDROID_R;
2444 } else {
2445 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2446 }
2447 std::vector<OperandType> inExpectedTypes = {inputType, OperandType::TENSOR_INT32,
2448 OperandType::INT32};
2449 std::vector<OperandType> outExpectedTypes = {inputType};
2450 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2451 outputIndexes, outExpectedTypes));
2452 return version;
2453 }
2454 case OperationType::ARGMAX:
2455 case OperationType::ARGMIN: {
2456 NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2457 << invalidInOutNumberMessage(2, 1);
2458 auto inputType = operands[inputIndexes[0]].type;
2459 std::vector<OperandType> inExpectedTypes;
2460 std::vector<OperandType> outExpectedTypes;
2461 if (inputType == OperandType::TENSOR_FLOAT16 ||
2462 inputType == OperandType::TENSOR_FLOAT32 ||
2463 inputType == OperandType::TENSOR_INT32 ||
2464 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2465 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2466 inExpectedTypes = {inputType, OperandType::INT32};
2467 outExpectedTypes = {OperandType::TENSOR_INT32};
2468 } else {
2469 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2470 }
2471 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2472 outputIndexes, outExpectedTypes));
2473 return Version::ANDROID_Q;
2474 }
2475 case OperationType::EXPAND_DIMS: {
2476 NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2477 << invalidInOutNumberMessage(2, 1);
2478 auto inputType = operands[inputIndexes[0]].type;
2479 std::vector<OperandType> inExpectedTypes;
2480 std::vector<OperandType> outExpectedTypes;
2481 if (inputType == OperandType::TENSOR_FLOAT16 ||
2482 inputType == OperandType::TENSOR_FLOAT32 ||
2483 inputType == OperandType::TENSOR_INT32 ||
2484 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2485 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2486 inExpectedTypes = {inputType, OperandType::INT32};
2487 outExpectedTypes = {inputType};
2488 } else {
2489 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2490 }
2491 Version version;
2492 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2493 version = Version::ANDROID_R;
2494 } else {
2495 version = Version::ANDROID_Q;
2496 }
2497 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2498 outputIndexes, outExpectedTypes));
2499 return version;
2500 }
2501 case OperationType::SPLIT: {
2502 NN_VALIDATE_EQ(inputIndexes.size(), 3u)
2503 << "Invalid number of input operands (" << inputIndexes.size()
2504 << ", expected 3)" << opType;
2505 auto inputType = operands[inputIndexes[0]].type;
2506 NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 ||
2507 inputType == OperandType::TENSOR_FLOAT32 ||
2508 inputType == OperandType::TENSOR_INT32 ||
2509 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2510 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
2511 << "Unsupported input tensor type for operation " << opType;
2512 Version version;
2513 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2514 version = Version::ANDROID_R;
2515 } else {
2516 version = Version::ANDROID_Q;
2517 }
2518 std::vector<OperandType> inExpectedTypes = {inputType, OperandType::INT32,
2519 OperandType::INT32};
2520 std::vector<OperandType> outExpectedTypes(outputIndexes.size(), inputType);
2521 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2522 outputIndexes, outExpectedTypes));
2523 return version;
2524 }
2525 case OperationType::MAXIMUM:
2526 case OperationType::MINIMUM: {
2527 NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2528 << invalidInOutNumberMessage(2, 1);
2529 std::vector<OperandType> inExpectedTypes;
2530 std::vector<OperandType> outExpectedTypes;
2531 OperandType inputType = operands[inputIndexes[0]].type;
2532 if (inputType == OperandType::TENSOR_FLOAT16 ||
2533 inputType == OperandType::TENSOR_FLOAT32 ||
2534 inputType == OperandType::TENSOR_INT32 ||
2535 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2536 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2537 inExpectedTypes = {inputType, inputType};
2538 outExpectedTypes = {inputType};
2539 } else {
2540 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2541 }
2542 Version version;
2543 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2544 version = Version::ANDROID_R;
2545 } else {
2546 version = Version::ANDROID_Q;
2547 }
2548 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2549 outputIndexes, outExpectedTypes));
2550 return version;
2551 }
2552 case OperationType::GROUPED_CONV_2D: {
2553 NN_VALIDATE((inputIndexes.size() == 12 || inputIndexes.size() == 9) &&
2554 outputIndexes.size() == 1)
2555 << "Invalid number of input operands (" << inputIndexes.size()
2556 << ", expected 12 or 9) or output operands (" << outputIndexes.size()
2557 << ", expected 1) for operation " << opType;
2558 auto inputType = operands[inputIndexes[0]].type;
2559 auto filterType = operands[inputIndexes[1]].type;
2560 std::vector<OperandType> inExpectedTypes;
2561 std::vector<OperandType> outExpectedTypes;
2562 if (inputType == OperandType::TENSOR_FLOAT32) {
2563 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
2564 OperandType::TENSOR_FLOAT32, OperandType::INT32,
2565 OperandType::INT32, OperandType::INT32,
2566 OperandType::INT32, OperandType::INT32};
2567 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
2568 } else if (inputType == OperandType::TENSOR_FLOAT16) {
2569 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
2570 OperandType::TENSOR_FLOAT16, OperandType::INT32,
2571 OperandType::INT32, OperandType::INT32,
2572 OperandType::INT32, OperandType::INT32};
2573 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
2574 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2575 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2576 NN_VALIDATE(filterType == inputType ||
2577 filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)
2578 << "Unsupported filter tensor type for operation " << opType;
2579
2580 NN_VALIDATE(filterType != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
2581 std::get<Operand::SymmPerChannelQuantParams>(
2582 operands[inputIndexes[1]].extraParams)
2583 .channelDim == 0)
2584 << "Unsupported filter tensor channel dimension for operation " << opType;
2585
2586 inExpectedTypes = {
2587 inputType, filterType, OperandType::TENSOR_INT32,
2588 OperandType::INT32, OperandType::INT32, OperandType::INT32,
2589 OperandType::INT32, OperandType::INT32};
2590 outExpectedTypes = {inputType};
2591 } else {
2592 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2593 }
2594
2595 if (inputIndexes.size() == 12) {
2596 std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
2597 inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
2598 explicitScalarTypes.end());
2599 }
2600 inExpectedTypes.push_back(OperandType::BOOL);
2601 Version version;
2602 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2603 version = Version::ANDROID_R;
2604 } else {
2605 version = Version::ANDROID_Q;
2606 }
2607 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2608 outputIndexes, outExpectedTypes));
2609 return version;
2610 }
2611 case OperationType::TILE: {
2612 NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2613 << invalidInOutNumberMessage(2, 1);
2614 auto inputType = operands[inputIndexes[0]].type;
2615 std::vector<OperandType> inExpectedTypes;
2616 std::vector<OperandType> outExpectedTypes;
2617 if (inputType == OperandType::TENSOR_FLOAT16 ||
2618 inputType == OperandType::TENSOR_FLOAT32 ||
2619 inputType == OperandType::TENSOR_INT32 ||
2620 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
2621 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2622 inExpectedTypes = {inputType, OperandType::TENSOR_INT32};
2623 outExpectedTypes = {inputType};
2624 } else {
2625 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2626 }
2627 Version version;
2628 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2629 version = Version::ANDROID_R;
2630 } else {
2631 version = Version::ANDROID_Q;
2632 }
2633 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2634 outputIndexes, outExpectedTypes));
2635 return version;
2636 }
2637 case OperationType::POW: {
2638 NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1)
2639 << invalidInOutNumberMessage(2, 1);
2640 auto inputType = operands[inputIndexes[0]].type;
2641 std::vector<OperandType> inExpectedTypes;
2642 std::vector<OperandType> outExpectedTypes;
2643 if (inputType == OperandType::TENSOR_FLOAT16 ||
2644 inputType == OperandType::TENSOR_FLOAT32) {
2645 inExpectedTypes = {inputType, inputType};
2646 outExpectedTypes = {inputType};
2647 } else {
2648 NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType;
2649 }
2650 Version version;
2651 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
2652 version = Version::ANDROID_R;
2653 } else {
2654 version = Version::ANDROID_Q;
2655 }
2656 NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes,
2657 outputIndexes, outExpectedTypes));
2658 return version;
2659 }
2660 case OperationType::IF: {
2661 return validateIfOperation(inputIndexes, outputIndexes, operands, subgraphs);
2662 }
2663 case OperationType::WHILE: {
2664 return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs);
2665 }
2666 default: {
2667 const OperationRegistration* operationRegistration =
2668 BuiltinOperationResolver::get()->findOperation(
2669 static_cast<OperationType>(opType));
2670 // TODO: return ErrorStatus::UNEXPECTED_NULL
2671 NN_VALIDATE(operationRegistration != nullptr) << opType << " not registered";
2672 // TODO: return ErrorStatus::UNEXPECTED_NULL
2673 NN_VALIDATE(operationRegistration->validate != nullptr)
2674 << "Incomplete operation registration: " << opType;
2675
2676 OperationValidationContext context(operationRegistration->name, inputIndexes,
2677 outputIndexes, operands);
2678 return operationRegistration->validate(&context);
2679 }
2680 }
2681 }
2682
validateOperationIncludingOperandVersions(const Operation & operation,const std::vector<Operand> & operands,const std::vector<Version> & operandVersions,const std::vector<Model::Subgraph> & subgraphs)2683 Result<Version> validateOperationIncludingOperandVersions(
2684 const Operation& operation, const std::vector<Operand>& operands,
2685 const std::vector<Version>& operandVersions,
2686 const std::vector<Model::Subgraph>& subgraphs) {
2687 auto version = NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs));
2688 for (uint32_t index : operation.inputs) {
2689 version = combineVersions(version, operandVersions[index]);
2690 }
2691 for (uint32_t index : operation.outputs) {
2692 version = combineVersions(version, operandVersions[index]);
2693 }
2694 return version;
2695 }
2696
2697 } // anonymous namespace
2698
// Below this point are all the functions that are declared in Validation.h. The convention of this
// file is to keep the bodies of the functions declared in Validation.h minimal, meaning that most
// functions below simply redirect to one of the functions defined above in the anonymous
// namespace. If a function below would clash in name with one of the functions above, the name
// of the function in the anonymous namespace is suffixed with "Impl".
2704
combineVersions(Version lhs,Version rhs)2705 Version combineVersions(Version lhs, Version rhs) {
2706 return std::max<Version>(lhs, rhs);
2707 }
2708
// Validates a DeviceStatus by forwarding to validateDeviceStatus and returning its result
// unchanged.
Result<Version> validate(const DeviceStatus& deviceStatus) {
    return validateDeviceStatus(deviceStatus);
}
2712
// Validates an ExecutionPreference by forwarding to validateExecutionPreference and returning its
// result unchanged.
Result<Version> validate(const ExecutionPreference& executionPreference) {
    return validateExecutionPreference(executionPreference);
}
2716
// Validates a DeviceType by forwarding to validateDeviceType and returning its result unchanged.
Result<Version> validate(const DeviceType& deviceType) {
    return validateDeviceType(deviceType);
}
2720
// Validates a MeasureTiming by forwarding to validateMeasureTiming and returning its result
// unchanged.
Result<Version> validate(const MeasureTiming& measureTiming) {
    return validateMeasureTiming(measureTiming);
}
2724
// Validates an OperandType by forwarding to the single-argument validateOperandType overload and
// returning its result unchanged.
Result<Version> validate(const OperandType& operandType) {
    return validateOperandType(operandType);
}
2728
// Validates a Priority by forwarding to validatePriority and returning its result unchanged.
Result<Version> validate(const Priority& priority) {
    return validatePriority(priority);
}
2732
// Validates an ErrorStatus by forwarding to validateErrorStatus and returning its result
// unchanged.
Result<Version> validate(const ErrorStatus& errorStatus) {
    return validateErrorStatus(errorStatus);
}
2736
// Validates a FusedActivationFunc by forwarding to validateFusedActivationFunc and returning its
// result unchanged.
Result<Version> validate(const FusedActivationFunc& activation) {
    return validateFusedActivationFunc(activation);
}
2740
// Validates a single OutputShape by forwarding to validateOutputShape and returning its result
// unchanged.
Result<Version> validate(const OutputShape& outputShape) {
    return validateOutputShape(outputShape);
}
2744
// Validates a Timing by forwarding to validateTiming and returning its result unchanged.
Result<Version> validate(const Timing& timing) {
    return validateTiming(timing);
}
2748
// Validates a Capabilities object by forwarding to validateCapabilities and returning its result
// unchanged.
Result<Version> validate(const Capabilities& capabilities) {
    return validateCapabilities(capabilities);
}
2752
// Validates a single Extension by forwarding to validateExtension and returning its result
// unchanged.
Result<Version> validate(const Extension& extension) {
    return validateExtension(extension);
}
2756
// Validates a single SharedHandle by forwarding to validateSharedHandle and returning its result
// unchanged.
Result<Version> validate(const SharedHandle& handle) {
    return validateSharedHandle(handle);
}
2760
// Validates a SharedMemory by forwarding to validateSharedMemory and returning its result
// unchanged.
Result<Version> validate(const SharedMemory& memory) {
    return validateSharedMemory(memory);
}
2764
// Validates a Model by forwarding to validateModel and returning its result unchanged.
Result<Version> validate(const Model& model) {
    return validateModel(model);
}
2768
// Validates a BufferDesc by forwarding to validateBufferDesc and returning its result unchanged.
Result<Version> validate(const BufferDesc& bufferDesc) {
    return validateBufferDesc(bufferDesc);
}
2772
// Validates a single BufferRole by forwarding to validateBufferRole and returning its result
// unchanged.
Result<Version> validate(const BufferRole& bufferRole) {
    return validateBufferRole(bufferRole);
}
2776
// Validates a Request on its own (no model is consulted here) by forwarding to validateRequest
// and returning its result unchanged.
Result<Version> validate(const Request& request) {
    return validateRequest(request);
}
2780
// Validates an OptionalTimePoint by forwarding to validateOptionalTimePoint and returning its
// result unchanged.
Result<Version> validate(const OptionalTimePoint& optionalTimePoint) {
    return validateOptionalTimePoint(optionalTimePoint);
}
2784
// Validates an OptionalDuration by forwarding to validateOptionalTimeoutDuration and returning its
// result unchanged.
Result<Version> validate(const OptionalDuration& optionalTimeoutDuration) {
    return validateOptionalTimeoutDuration(optionalTimeoutDuration);
}
2788
// Validates a CacheToken by forwarding to validateCacheToken and returning its result unchanged.
Result<Version> validate(const CacheToken& cacheToken) {
    return validateCacheToken(cacheToken);
}
2792
// Validates a single SyncFence by forwarding to validateSyncFence and returning its result
// unchanged.
Result<Version> validate(const SyncFence& syncFence) {
    return validateSyncFence(syncFence);
}
2796
// Validates a list of OutputShape objects by handing the vector, together with
// validateOutputShape, to validateVector and returning its result unchanged.
Result<Version> validate(const std::vector<OutputShape>& outputShapes) {
    return validateVector(outputShapes, validateOutputShape);
}
2800
// Validates a list of Extension objects. Unlike the other vector overloads in this file, this one
// forwards to the dedicated validateExtensions function rather than to validateVector.
Result<Version> validate(const std::vector<Extension>& extensions) {
    return validateExtensions(extensions);
}
2804
// Validates a list of SharedHandle objects by handing the vector, together with
// validateSharedHandle, to validateVector and returning its result unchanged.
Result<Version> validate(const std::vector<SharedHandle>& handles) {
    return validateVector(handles, validateSharedHandle);
}
2808
// Validates a list of BufferRole objects by handing the vector, together with validateBufferRole,
// to validateVector and returning its result unchanged.
Result<Version> validate(const std::vector<BufferRole>& bufferRoles) {
    return validateVector(bufferRoles, validateBufferRole);
}
2812
// Validates a list of SyncFence objects by handing the vector, together with validateSyncFence,
// to validateVector and returning its result unchanged.
Result<Version> validate(const std::vector<SyncFence>& syncFences) {
    return validateVector(syncFences, validateSyncFence);
}
2816
// Validates `request` against `model` by forwarding all three arguments, unchanged, to
// validateRequestForModelImpl and returning its result.
//
// `allowUnspecifiedOutput` is passed straight through; its exact effect is defined by
// validateRequestForModelImpl.
Result<Version> validateRequestForModel(const Request& request, const Model& model,
                                        bool allowUnspecifiedOutput) {
    return validateRequestForModelImpl(request, model, allowUnspecifiedOutput);
}
2821
// Validates a memory descriptor by forwarding every argument, unchanged and in the same order, to
// validateMemoryDescImpl and returning its result.
//
// NOTE(review): `preparedModelRoles` and `combinedOperand` are non-const pointers and are
// presumably output parameters filled in by validateMemoryDescImpl -- confirm there, including
// whether either may be null.
Result<Version> validateMemoryDesc(
        const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels,
        const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles,
        const std::function<const Model*(const SharedPreparedModel&)>& getModel,
        std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand) {
    return validateMemoryDescImpl(desc, preparedModels, inputRoles, outputRoles, getModel,
                                  preparedModelRoles, combinedOperand);
}
2830
// Validates the symmetric per-channel quantization parameters `channelQuant` for `operand` by
// forwarding all arguments, unchanged, to validateOperandSymmPerChannelQuantParamsImpl.
//
// `tag` is passed through as-is (presumably used to label error messages -- confirm in the Impl).
Result<void> validateOperandSymmPerChannelQuantParams(
        const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant,
        const char* tag) {
    return validateOperandSymmPerChannelQuantParamsImpl(operand, channelQuant, tag);
}
2836
// Validates the type information of the operand `type` by forwarding all four arguments,
// unchanged, to validateOperandTypeImpl and returning its result.
//
// The meaning of `extensionOperandTypeInfo`, `tag` and `allowPartial` is defined by
// validateOperandTypeImpl; this wrapper adds no checks of its own.
Result<void> validateOperandType(const Operand& type,
                                 const Extension::OperandTypeInformation* extensionOperandTypeInfo,
                                 const char* tag, bool allowPartial) {
    return validateOperandTypeImpl(type, extensionOperandTypeInfo, tag, allowPartial);
}
2842
// Validates a list of operand indexes by forwarding `list`, `operandCount` and `tag`, unchanged,
// to validateOperandListImpl and returning its result.
Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount,
                                 const char* tag) {
    return validateOperandListImpl(list, operandCount, tag);
}
2847
validateOperationButNotOperands(const Operation & operation,const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs)2848 Result<void> validateOperationButNotOperands(const Operation& operation,
2849 const std::vector<Operand>& operands,
2850 const std::vector<Model::Subgraph>& subgraphs) {
2851 NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs));
2852 return {};
2853 }
2854
// Definition of the cache object that createSubgraphVersionCache hands out and that the
// validate*AndAnythingItDependsOn functions accept.
struct SubgraphVersionCache {
    // Sized by createSubgraphVersionCache to one slot per referenced subgraph. Presumably an empty
    // optional means that subgraph's version has not been computed yet -- TODO confirm against
    // validateOperand, which receives a pointer to this vector.
    std::vector<std::optional<Version>> cache;
};
2858
createSubgraphVersionCache(size_t referencedSubgraphCount)2859 std::unique_ptr<SubgraphVersionCache, void (*)(SubgraphVersionCache*)> createSubgraphVersionCache(
2860 size_t referencedSubgraphCount) {
2861 auto subgraphVersionCache = std::make_unique<SubgraphVersionCache>();
2862 subgraphVersionCache->cache.resize(referencedSubgraphCount);
2863 constexpr auto deleter = [](SubgraphVersionCache* ptr) { delete ptr; };
2864 return {subgraphVersionCache.release(), deleter};
2865 }
2866
validateOperationAndAnythingItDependsOn(const Operation & operation,const std::vector<Operand> & operands,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,SubgraphVersionCache * subgraphVersionCache)2867 Result<Version> validateOperationAndAnythingItDependsOn(
2868 const Operation& operation, const std::vector<Operand>& operands, size_t operandValuesSize,
2869 const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs,
2870 SubgraphVersionCache* subgraphVersionCache) {
2871 CHECK(subgraphVersionCache != nullptr);
2872 std::vector<Version> operandVersions(operands.size(), Version::ANDROID_OC_MR1);
2873 for (uint32_t index : operation.inputs) {
2874 NN_VALIDATE_LT(index, operands.size());
2875 const Operand& operand = operands[index];
2876 operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn(
2877 operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache));
2878 }
2879 for (uint32_t index : operation.outputs) {
2880 NN_VALIDATE_LT(index, operands.size());
2881 const Operand& operand = operands[index];
2882 operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn(
2883 operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache));
2884 }
2885 return validateOperationIncludingOperandVersions(operation, operands, operandVersions,
2886 subgraphs);
2887 }
2888
validateOperandAndAnythingItDependsOn(const Operand & operand,size_t operandValuesSize,const std::vector<size_t> & poolSizes,const std::vector<Model::Subgraph> & subgraphs,SubgraphVersionCache * subgraphVersionCache)2889 Result<Version> validateOperandAndAnythingItDependsOn(const Operand& operand,
2890 size_t operandValuesSize,
2891 const std::vector<size_t>& poolSizes,
2892 const std::vector<Model::Subgraph>& subgraphs,
2893 SubgraphVersionCache* subgraphVersionCache) {
2894 CHECK(subgraphVersionCache != nullptr);
2895 return validateOperand(operand, operandValuesSize, poolSizes, subgraphs,
2896 &subgraphVersionCache->cache);
2897 }
2898
2899 } // namespace android::nn
2900