/* * Copyright (C) 2020 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "Validation.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "ControlFlow.h" #include "OperandTypes.h" #include "OperationResolver.h" #include "OperationTypes.h" #include "Result.h" #include "SharedMemory.h" #include "TypeUtils.h" #include "Types.h" // The NN_VALIDATE family of macros defined below is similar to the CHECK family defined in // system/libbase/include/android-base/logging.h // // The difference is that NN_VALIDATE macros use LOG(ERROR) instead of LOG(FATAL) // and return false instead of aborting. // Logs an error and returns false or INVALID. Append context using << after. For example: // // NN_VALIDATE_FAIL() << "Something went wrong"; // // The containing function must return a bool or Version. #define NN_VALIDATE_FAIL() \ return NN_ERROR() << "NN_VALIDATE failed (" << __FILE__ << ":" << __LINE__ << "): " // Logs an error and returns false or Version::INVALID if condition is false. Extra logging can be // appended using << after. For example: // // NN_VALIDATE(false) << "Something went wrong"; // // The containing function must return a bool. #define NN_VALIDATE(condition) \ while (UNLIKELY(!(condition))) NN_VALIDATE_FAIL() << #condition << " " // Helper for NN_VALIDATE_xx(x, y) macros. #define NN_VALIDATE_OP(LHS, RHS, OP) \ for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS); \ UNLIKELY(!(_values.lhs.v OP _values.rhs.v)); \ /* empty */) \ NN_VALIDATE_FAIL() \ << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = " \ << ::android::base::LogNullGuard::Guard(_values.lhs.v) \ << ", " << #RHS << " = " \ << ::android::base::LogNullGuard::Guard(_values.rhs.v) \ << ") " // Logs an error and returns false or Version::INVALID if a condition between x and y does not hold. // Extra logging can be appended using << after. For example: // // NN_VALIDATE_EQ(a, b) << "Something went wrong"; // // The values must implement the appropriate comparison operator as well as // `operator<<(std::ostream&, ...)`. // The containing function must return a bool or Version. #define NN_VALIDATE_EQ(x, y) NN_VALIDATE_OP(x, y, ==) #define NN_VALIDATE_NE(x, y) NN_VALIDATE_OP(x, y, !=) #define NN_VALIDATE_LE(x, y) NN_VALIDATE_OP(x, y, <=) #define NN_VALIDATE_LT(x, y) NN_VALIDATE_OP(x, y, <) #define NN_VALIDATE_GE(x, y) NN_VALIDATE_OP(x, y, >=) #define NN_VALIDATE_GT(x, y) NN_VALIDATE_OP(x, y, >) namespace android::nn { namespace { constexpr auto kNullptrVariant = std::variant{}; constexpr auto kInvalidMemoryDomainToken = Request::MemoryDomainToken{}; template Result validateVector(const std::vector& objects, const ValidationFunction& validationFunction) { auto version = Version::ANDROID_OC_MR1; for (const auto& object : objects) { version = combineVersions(version, NN_TRY(validationFunction(object))); } return version; } bool isValidExtensionName(const std::string& name) { constexpr auto validSymbol = [](char symbol) { return std::islower(symbol) || std::isdigit(symbol) || symbol == '.' || symbol == '_'; }; const bool hasOnlyValidSymbols = std::all_of(name.begin(), name.end(), validSymbol); const bool hasAtLeastOnePeriod = std::find(name.begin(), name.end(), '.') != name.end(); return hasOnlyValidSymbols && hasAtLeastOnePeriod; } Result validateDeviceStatus(const DeviceStatus& deviceStatus) { switch (deviceStatus) { case DeviceStatus::AVAILABLE: case DeviceStatus::BUSY: case DeviceStatus::OFFLINE: case DeviceStatus::UNKNOWN: return Version::ANDROID_OC_MR1; } NN_VALIDATE_FAIL() << "Invalid DeviceStatus " << deviceStatus; } Result validateExecutionPreference(const ExecutionPreference& executionPreference) { switch (executionPreference) { case ExecutionPreference::FAST_SINGLE_ANSWER: // ExecutionPreference::FAST_SINGLE_ANSWER is the default value, so it is implicitly // valid for all versions. return Version::ANDROID_OC_MR1; case ExecutionPreference::LOW_POWER: case ExecutionPreference::SUSTAINED_SPEED: return Version::ANDROID_P; } NN_VALIDATE_FAIL() << "Invalid ExecutionPreference " << executionPreference; } Result validateDeviceType(const DeviceType& deviceType) { switch (deviceType) { case DeviceType::UNKNOWN: // DeviceType was introduced in the 1.2 NN HAL. DeviceType::UNKNOWN is returned when // querying versions that are prior to the 1.2 NN HAL. DeviceType::UNKNOWN is not a // valid code to return for a driver that implement at least a 1.2 NN HAL. If we need a // range of versions, make ANDROID_Q (NN HAL 1.2) the exclusive upper bound for // DeviceType::UNKNOWN. return Version::ANDROID_OC_MR1; case DeviceType::OTHER: case DeviceType::CPU: case DeviceType::GPU: case DeviceType::ACCELERATOR: return Version::ANDROID_Q; } NN_VALIDATE_FAIL() << "Invalid DeviceType " << deviceType; } Result validateMeasureTiming(const MeasureTiming& measureTiming) { switch (measureTiming) { case MeasureTiming::NO: // MeasureTiming::NO is the default value, so it is implicitly valid for all versions. return Version::ANDROID_OC_MR1; case MeasureTiming::YES: return Version::ANDROID_Q; } NN_VALIDATE_FAIL() << "Invalid MeasureTiming " << measureTiming; } Result validateOperandType(const OperandType& operandType) { switch (operandType) { case OperandType::FLOAT32: case OperandType::INT32: case OperandType::UINT32: case OperandType::TENSOR_FLOAT32: case OperandType::TENSOR_INT32: case OperandType::TENSOR_QUANT8_ASYMM: case OperandType::OEM: case OperandType::TENSOR_OEM_BYTE: return Version::ANDROID_OC_MR1; case OperandType::BOOL: case OperandType::TENSOR_QUANT16_SYMM: case OperandType::TENSOR_FLOAT16: case OperandType::TENSOR_BOOL8: case OperandType::FLOAT16: case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: case OperandType::TENSOR_QUANT16_ASYMM: case OperandType::TENSOR_QUANT8_SYMM: return Version::ANDROID_Q; case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: case OperandType::SUBGRAPH: return Version::ANDROID_R; } if (isExtension(operandType)) { return Version::ANDROID_Q; } NN_VALIDATE_FAIL() << "Invalid OperandType " << operandType; } Result validateOperandLifeTime(const Operand& operand) { // Make sure SUBGRAPH operand type and lifetime always go together. NN_VALIDATE_EQ((operand.type == OperandType::SUBGRAPH), (operand.lifetime == Operand::LifeTime::SUBGRAPH)) << "Operand of type " << operand.type << " cannot have lifetime " << operand.lifetime; switch (operand.lifetime) { case Operand::LifeTime::TEMPORARY_VARIABLE: case Operand::LifeTime::SUBGRAPH_INPUT: case Operand::LifeTime::SUBGRAPH_OUTPUT: case Operand::LifeTime::CONSTANT_COPY: case Operand::LifeTime::CONSTANT_REFERENCE: case Operand::LifeTime::NO_VALUE: case Operand::LifeTime::POINTER: return Version::ANDROID_OC_MR1; case Operand::LifeTime::SUBGRAPH: return Version::ANDROID_R; } NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime; } Result validatePriority(const Priority& priority) { switch (priority) { case Priority::MEDIUM: // Priority::MEDIUM is the default value, so it is implicitly valid for all versions. return Version::ANDROID_OC_MR1; case Priority::LOW: case Priority::HIGH: return Version::ANDROID_R; } NN_VALIDATE_FAIL() << "Invalid Priority " << priority; } Result validateErrorStatus(const ErrorStatus& errorStatus) { // Note that MISSED_DEADLINE_*, RESOURCE_EXHAUSTED_*, and DEAD_OBJECT were introduced ih // ANDROID_R, but these can be cast to ANDROID_OC_MR1 as GENERAL_FAILURE. switch (errorStatus) { case ErrorStatus::NONE: case ErrorStatus::DEVICE_UNAVAILABLE: case ErrorStatus::GENERAL_FAILURE: case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE: case ErrorStatus::INVALID_ARGUMENT: case ErrorStatus::MISSED_DEADLINE_TRANSIENT: case ErrorStatus::MISSED_DEADLINE_PERSISTENT: case ErrorStatus::RESOURCE_EXHAUSTED_TRANSIENT: case ErrorStatus::RESOURCE_EXHAUSTED_PERSISTENT: case ErrorStatus::DEAD_OBJECT: return Version::ANDROID_OC_MR1; } NN_VALIDATE_FAIL() << "Invalid ErrorStatus " << errorStatus; } Result validateFusedActivationFunc(const FusedActivationFunc& activation) { switch (activation) { case FusedActivationFunc::NONE: case FusedActivationFunc::RELU: case FusedActivationFunc::RELU1: case FusedActivationFunc::RELU6: return Version::ANDROID_OC_MR1; } NN_VALIDATE_FAIL() << "Invalid FusedActivationFunc " << activation; } Result validateOutputShape(const OutputShape& /*outputShape*/) { return Version::ANDROID_Q; } Result validateTiming(const Timing& timing) { constexpr auto kNoTiming = Timing{}; if (timing == kNoTiming) { // kNoTiming is the default value, so it is implicitly valid for all versions. return Version::ANDROID_OC_MR1; } if (timing.timeInDriver.has_value() && timing.timeOnDevice.has_value()) { // `lazyMessage` is a lazy function to produce the timing validation error message. // Currently, the code is not able to inline the message in NN_VALIDATE due to a // argument-dependent lookup issue with nn::detail::ErrorBuilder interacting with std types // such as std::chrono::duration, so this function uses an indirection through // std::ostringstream. const auto lazyMessage = [&timing]() -> std::string { std::ostringstream oss; oss << "Timing::timeOnDevice (" << timing.timeOnDevice.value() << ") must not exceed Timing::timeInDriver (" << timing.timeInDriver.value() << ")"; return oss.str(); }; NN_VALIDATE(timing.timeOnDevice.value() <= timing.timeInDriver.value()) << lazyMessage(); } return Version::ANDROID_Q; } Result validateCapabilitiesPerformanceInfo( const Capabilities::PerformanceInfo& performanceInfo) { NN_VALIDATE_GT(performanceInfo.execTime, 0.0f); NN_VALIDATE_GT(performanceInfo.powerUsage, 0.0f); return Version::ANDROID_OC_MR1; } Result validateCapabilitiesOperandPerformance( const Capabilities::OperandPerformance& operandPerformance) { auto version = NN_TRY(validateOperandType(operandPerformance.type)); return combineVersions(version, NN_TRY(validateCapabilitiesPerformanceInfo(operandPerformance.info))); } Result validateCapabilitiesOperandPerformanceTable( const Capabilities::OperandPerformanceTable& operandPerformances) { // OperandPerformanceTable's order was validated when it was created, and it is castable to any // version. If an OperandType does not exist in the lower version being converted to, that // OperandPerformance will be dropped. NN_TRY(validateVector(operandPerformances.asVector(), validateCapabilitiesOperandPerformance)); return Version::ANDROID_OC_MR1; } Result validateCapabilities(const Capabilities& capabilities) { auto version = NN_TRY(validateCapabilitiesOperandPerformanceTable(capabilities.operandPerformance)); version = combineVersions(version, NN_TRY(validateCapabilitiesPerformanceInfo( capabilities.relaxedFloat32toFloat16PerformanceScalar))); version = combineVersions(version, NN_TRY(validateCapabilitiesPerformanceInfo( capabilities.relaxedFloat32toFloat16PerformanceTensor))); version = combineVersions( version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.ifPerformance))); version = combineVersions( version, NN_TRY(validateCapabilitiesPerformanceInfo(capabilities.whilePerformance))); return version; } Result validateExtensionOperandTypeInformation( const Extension::OperandTypeInformation& operandTypeInformation) { NN_VALIDATE_GT(operandTypeInformation.byteSize, 0u); return Version::ANDROID_Q; } Result validateExtension(const Extension& extension) { NN_VALIDATE(isValidExtensionName(extension.name)); // Verify all OperandTypeInformations have unique types. std::vector types; types.reserve(extension.operandTypes.size()); std::transform(extension.operandTypes.begin(), extension.operandTypes.end(), std::back_inserter(types), [](const Extension::OperandTypeInformation& operandTypeInformation) { return operandTypeInformation.type; }); std::sort(types.begin(), types.end()); const auto iter = std::adjacent_find(types.begin(), types.end()); NN_VALIDATE(iter == types.end()) << "Extension has duplicate type " << *iter; return combineVersions(Version::ANDROID_Q, NN_TRY(validateVector(extension.operandTypes, validateExtensionOperandTypeInformation))); } Result validateExtensions(const std::vector& extensions) { const auto version = NN_TRY(validateVector(extensions, validateExtension)); // Verify all extensions have unique names. std::vector> names; names.reserve(extensions.size()); std::transform(extensions.begin(), extensions.end(), std::back_inserter(names), [](const Extension& extension) { return std::cref(extension.name); }); std::sort(names.begin(), names.end(), std::less{}); const auto nameIter = std::adjacent_find(names.begin(), names.end(), std::equal_to{}); NN_VALIDATE(nameIter == names.end()) << "Two or more extensions have the duplicate name " << nameIter->get(); return version; } // Forward declaration of subgraph validation function. Result validateModelSubgraph(const Model::Subgraph& subgraph, std::optional referencedIndex, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& referenced, std::vector>* subgraphVersionCache); Result validateOperandDataLocation( const Operand& operand, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& subgraphs, std::vector>* subgraphVersionCache) { const DataLocation& location = operand.location; NN_VALIDATE_EQ(location.padding, 0u) << "DataLocation with a non-zero padding used in Model: " << location.padding; switch (operand.lifetime) { case Operand::LifeTime::CONSTANT_COPY: NN_VALIDATE(location.pointer == kNullptrVariant) << "CONSTANT_COPY with a non-null pointer"; NN_VALIDATE_EQ(location.poolIndex, 0u) << "CONSTANT_COPY with a non-zero poolIndex " << location.poolIndex; // Do the addition using uint64_t to avoid potential wrap-around problems. NN_VALIDATE_LE(static_cast(location.offset) + location.length, operandValuesSize) << "OperandValue location out of range. Starts at " << location.offset << ", length " << location.length << ", max " << operandValuesSize; return Version::ANDROID_OC_MR1; case Operand::LifeTime::CONSTANT_REFERENCE: NN_VALIDATE_LT(location.poolIndex, poolSizes.size()); // Do the addition using uint64_t to avoid potential wrap-around problems. NN_VALIDATE_LE(static_cast(location.offset) + location.length, poolSizes[location.poolIndex]) << "OperandValue location out of range. Starts at " << location.offset << ", length " << location.length << ", max " << poolSizes[location.poolIndex]; return Version::ANDROID_OC_MR1; case Operand::LifeTime::TEMPORARY_VARIABLE: case Operand::LifeTime::SUBGRAPH_INPUT: case Operand::LifeTime::SUBGRAPH_OUTPUT: case Operand::LifeTime::NO_VALUE: NN_VALIDATE(location.pointer == kNullptrVariant) << "Unexpected pointer value for operand of lifetime " << operand.lifetime; NN_VALIDATE_EQ(location.poolIndex, 0u) << "Unexpected poolIndex " << location.poolIndex << " for operand of lifetime " << operand.lifetime; NN_VALIDATE_EQ(location.offset, 0u) << "Unexpected offset " << location.offset << " for operand of lifetime " << operand.lifetime; NN_VALIDATE_EQ(location.length, 0u) << "Unexpected length " << location.length << " for operand of lifetime " << operand.lifetime; return Version::ANDROID_OC_MR1; case Operand::LifeTime::SUBGRAPH: { NN_VALIDATE(location.pointer == kNullptrVariant) << "SUBGRAPH with a non-null pointer"; NN_VALIDATE_EQ(location.poolIndex, 0u) << "SUBGRAPH with a non-zero poolIndex " << location.poolIndex; NN_VALIDATE_LT(location.offset, subgraphs.size()) << "Subgraph index out of range: " << location.offset << " >= " << subgraphs.size(); NN_VALIDATE_EQ(location.length, 0u) << "SUBGRAPH with a non-zero length " << location.length; const auto version = NN_TRY(validateModelSubgraph( subgraphs[location.offset], location.offset, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache)); return combineVersions(version, Version::ANDROID_R); } case Operand::LifeTime::POINTER: { const bool nonNull = std::visit([](auto* ptr) { return ptr != nullptr; }, location.pointer); NN_VALIDATE(nonNull) << "POINTER with a null pointer"; NN_VALIDATE_EQ(location.poolIndex, 0u) << "POINTER with a non-zero poolIndex " << location.poolIndex; NN_VALIDATE_EQ(location.offset, 0u) << "POINTER with a non-zero offset " << location.offset; return Version::ANDROID_OC_MR1; } } NN_VALIDATE_FAIL() << "Invalid Operand::LifeTime " << operand.lifetime; } Result validateOperandDimensions(const Operand& operand) { switch (operand.type) { case OperandType::FLOAT32: case OperandType::INT32: case OperandType::UINT32: case OperandType::BOOL: case OperandType::FLOAT16: case OperandType::SUBGRAPH: case OperandType::OEM: NN_VALIDATE(operand.dimensions.empty()) << "Scalar data has dimensions of rank " << operand.dimensions.size(); return Version::ANDROID_OC_MR1; case OperandType::TENSOR_FLOAT32: case OperandType::TENSOR_INT32: case OperandType::TENSOR_QUANT8_ASYMM: case OperandType::TENSOR_QUANT16_SYMM: case OperandType::TENSOR_FLOAT16: case OperandType::TENSOR_BOOL8: case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: case OperandType::TENSOR_QUANT16_ASYMM: case OperandType::TENSOR_QUANT8_SYMM: case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: case OperandType::TENSOR_OEM_BYTE: { if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY || operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE || operand.lifetime == Operand::LifeTime::POINTER) { NN_VALIDATE(!operand.dimensions.empty()) << "Tensor has lifetime of " << operand.lifetime << " but dimensions of rank 0"; const auto size = getNonExtensionSize(operand); NN_VALIDATE(size.has_value()) << "Tensor dimensions overflow"; NN_VALIDATE_NE(size.value(), 0u) << "Tensor has at least one unknown dimension"; } // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before // Android Q? if (operand.dimensions.empty()) { // Unspecified rank was added in Android Q. return Version::ANDROID_Q; } return Version::ANDROID_OC_MR1; } } if (isExtension(operand.type)) { // Extension types were added in Android Q. return Version::ANDROID_Q; } NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type; } Result validateOperandScale(const Operand& operand) { switch (operand.type) { case OperandType::FLOAT32: case OperandType::INT32: case OperandType::UINT32: case OperandType::TENSOR_FLOAT32: case OperandType::BOOL: case OperandType::TENSOR_FLOAT16: case OperandType::TENSOR_BOOL8: case OperandType::FLOAT16: case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: case OperandType::SUBGRAPH: NN_VALIDATE_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type << " with a non-zero scale (" << operand.scale << ")"; return Version::ANDROID_OC_MR1; case OperandType::TENSOR_INT32: // TENSOR_INT32 may be used with or without scale, depending on the operation. // TODO(b/119869082) We should have a separate type for TENSOR_INT32 with a scale. NN_VALIDATE_GE(operand.scale, 0.0f) << "Operand of type " << operand.type << " with a negative scale"; return Version::ANDROID_OC_MR1; case OperandType::TENSOR_QUANT8_ASYMM: case OperandType::TENSOR_QUANT16_SYMM: case OperandType::TENSOR_QUANT16_ASYMM: case OperandType::TENSOR_QUANT8_SYMM: case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: NN_VALIDATE_GT(operand.scale, 0.0f) << "Operand of type " << operand.type << " with a non-positive scale"; return Version::ANDROID_OC_MR1; case OperandType::OEM: case OperandType::TENSOR_OEM_BYTE: // No validation for OEM types. return Version::ANDROID_OC_MR1; } if (isExtension(operand.type)) { NN_VALIDATE_EQ(operand.scale, 0.0f) << "Operand of type " << operand.type << " with a non-zero scale (" << operand.scale << ")"; return Version::ANDROID_Q; } NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type; } Result validateOperandZeroPoint(const Operand& operand) { switch (operand.type) { case OperandType::FLOAT32: case OperandType::INT32: case OperandType::UINT32: case OperandType::TENSOR_FLOAT32: case OperandType::TENSOR_INT32: case OperandType::BOOL: case OperandType::TENSOR_FLOAT16: case OperandType::TENSOR_BOOL8: case OperandType::FLOAT16: case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: case OperandType::TENSOR_QUANT8_SYMM: case OperandType::SUBGRAPH: NN_VALIDATE_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type << " with a non-zero zeroPoint " << operand.zeroPoint; return Version::ANDROID_OC_MR1; case OperandType::TENSOR_QUANT8_ASYMM: NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 255) << "Operand of type " << operand.type << " with an invalid zeroPoint " << operand.zeroPoint << ", must be in range [0, 255]"; return Version::ANDROID_OC_MR1; case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: NN_VALIDATE(operand.zeroPoint >= -128 && operand.zeroPoint <= 127) << "Operand of type " << operand.type << " with an invalid zeroPoint " << operand.zeroPoint << ", must be in range [-128, 127]"; return Version::ANDROID_OC_MR1; case OperandType::TENSOR_QUANT16_ASYMM: NN_VALIDATE(operand.zeroPoint >= 0 && operand.zeroPoint <= 65535) << "Operand of type " << operand.type << " with an invalid zeroPoint " << operand.zeroPoint << ", must be in range [0, 65535]"; return Version::ANDROID_OC_MR1; case OperandType::TENSOR_QUANT16_SYMM: NN_VALIDATE_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type << " with a non-zero zeroPoint " << operand.zeroPoint; return Version::ANDROID_OC_MR1; case OperandType::OEM: case OperandType::TENSOR_OEM_BYTE: // No validation for OEM types. return Version::ANDROID_OC_MR1; } if (isExtension(operand.type)) { NN_VALIDATE_EQ(operand.zeroPoint, 0) << "Operand of type " << operand.type << " with a non-zero zeroPoint " << operand.zeroPoint; return Version::ANDROID_Q; } NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type; } Result validateOperandExtraParams(const Operand& operand) { switch (operand.type) { case OperandType::FLOAT32: case OperandType::INT32: case OperandType::UINT32: case OperandType::TENSOR_FLOAT32: case OperandType::TENSOR_INT32: case OperandType::TENSOR_QUANT8_ASYMM: case OperandType::BOOL: case OperandType::TENSOR_QUANT16_SYMM: case OperandType::TENSOR_FLOAT16: case OperandType::TENSOR_BOOL8: case OperandType::FLOAT16: case OperandType::TENSOR_QUANT16_ASYMM: case OperandType::TENSOR_QUANT8_SYMM: case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: case OperandType::SUBGRAPH: NN_VALIDATE(std::holds_alternative(operand.extraParams)) << "Operand of type " << operand.type << " has extraParams when there must be none"; return Version::ANDROID_OC_MR1; case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: { NN_VALIDATE( std::holds_alternative(operand.extraParams)) << "Operand of type " << operand.type << " without a Channel Quantization params"; const auto& channelQuant = std::get(operand.extraParams); const size_t count = operand.dimensions.size(); NN_VALIDATE_LT(channelQuant.channelDim, count) << "Operand of type " << operand.type << " with an invalid channelQuant.channelDim " << channelQuant.channelDim << ", must be valid dimension index in range [0, " << count << ")"; const uint32_t expected = operand.dimensions[channelQuant.channelDim]; NN_VALIDATE_EQ(channelQuant.scales.size(), expected) << "Operand of type " << operand.type << " with a wrong-sized scales, expected " << expected << " was " << channelQuant.scales.size(); NN_VALIDATE_NE(expected, 0u) << "Operand of type " << operand.type << " channel dimension " << channelQuant.channelDim << " is underspecified (can't be 0)"; for (uint32_t i = 0; i < expected; ++i) { NN_VALIDATE_GT(channelQuant.scales[i], 0.0f) << "Operand of type " << operand.type << " with a non-positive value in scales[" << i << "]=" << channelQuant.scales[i]; } return Version::ANDROID_Q; } case OperandType::OEM: case OperandType::TENSOR_OEM_BYTE: // No validation for OEM types. return Version::ANDROID_OC_MR1; } if (isExtension(operand.type)) { NN_VALIDATE(std::holds_alternative(operand.extraParams) || std::holds_alternative(operand.extraParams)) << "Extension operand of type " << operand.type << " must not have SymmPerChannelQuant extraParams"; return Version::ANDROID_OC_MR1; } NN_VALIDATE_FAIL() << "Invalid OperandType " << operand.type; } Result validateOperand(const Operand& operand, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& subgraphs, std::vector>* subgraphVersionCache) { auto version = NN_TRY(validateOperandType(operand.type)); version = combineVersions(version, NN_TRY(validateOperandLifeTime(operand))); version = combineVersions(version, NN_TRY(validateOperandDimensions(operand))); version = combineVersions(version, NN_TRY(validateOperandScale(operand))); version = combineVersions(version, NN_TRY(validateOperandZeroPoint(operand))); version = combineVersions(version, NN_TRY(validateOperandExtraParams(operand))); version = combineVersions( version, NN_TRY(validateOperandDataLocation(operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache))); // For constants, validate that the length is as expected. The other lifetimes // expect the length to be 0. Don't validate for OEM types. if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE || operand.lifetime == Operand::LifeTime::CONSTANT_COPY || operand.lifetime == Operand::LifeTime::POINTER) { if (!isExtension(operand.type) && operand.type != OperandType::OEM && operand.type != OperandType::TENSOR_OEM_BYTE) { const auto expectedLength = getNonExtensionSize(operand).value(); NN_VALIDATE_EQ(operand.location.length, expectedLength) << "For operand " << operand.type << " expected a size of " << expectedLength << " but got " << operand.location.length; } } return version; } Result> validateOperands( const std::vector& operands, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& subgraphs, std::vector>* subgraphVersionCache) { std::vector versions; versions.reserve(operands.size()); for (size_t i = 0; i < operands.size(); ++i) { auto result = validateOperand(operands[i], operandValuesSize, poolSizes, subgraphs, subgraphVersionCache); if (!result.has_value()) { return error() << std::move(result).error() << " for operand " << i; } versions.push_back(result.value()); } return versions; } // Forward declaration. Result validateOperationIncludingOperandVersions( const Operation& operation, const std::vector& operands, const std::vector& operandVersions, const std::vector& subgraphs); Result validateOperations(const std::vector& operations, const std::vector& operands, const std::vector& operandVersions, const std::vector& subgraphs) { auto version = Version::ANDROID_OC_MR1; for (size_t i = 0; i < operations.size(); ++i) { auto result = validateOperationIncludingOperandVersions(operations[i], operands, operandVersions, subgraphs); if (!result.has_value()) { return error() << std::move(result).error() << " for operation " << i; } version = combineVersions(version, result.value()); } return version; } Result validateHandle(const Handle& handle) { NN_VALIDATE(std::all_of(handle.fds.begin(), handle.fds.end(), [](const base::unique_fd& fd) { return fd.ok(); })); return Version::ANDROID_OC_MR1; } Result validateSharedHandle(const SharedHandle& handle) { NN_VALIDATE(handle != nullptr); return validateHandle(*handle); } Result validateMemory(const Memory::Ashmem& memory) { NN_VALIDATE(memory.fd.ok()); NN_VALIDATE_NE(memory.size, 0u); return Version::ANDROID_OC_MR1; } Result validateMemory(const Memory::Fd& memory) { NN_VALIDATE(memory.fd.ok()); NN_VALIDATE_NE(memory.size, 0u); // `prot` is allowed to be either PROT_NONE (which has a value of 0) or the bitwise OR of either // PROT_READ or PROT_WRITE. If any other bits are set, the `prot` field is invalid. constexpr int kAllowedBits = PROT_READ | PROT_WRITE; NN_VALIDATE_EQ(memory.prot & ~kAllowedBits, 0); return Version::ANDROID_OC_MR1; } Result validateMemory(const Memory::HardwareBuffer& memory) { NN_VALIDATE(memory.handle.get() != nullptr); return Version::ANDROID_Q; } Result validateMemory(const Memory::Unknown& memory) { NN_TRY(validateHandle(memory.handle)); return Version::ANDROID_Q; } Result validateSharedMemory(const SharedMemory& memory) { NN_VALIDATE(memory != nullptr); return std::visit([](const auto& x) { return validateMemory(x); }, memory->handle); } Result validateModelSubgraphInputOutputs(const std::vector& indexes, const std::vector& operands, Operand::LifeTime lifetime) { const size_t operandCount = operands.size(); for (uint32_t i : indexes) { NN_VALIDATE_LT(i, operandCount) << "Model " << lifetime << " input or output index out of range: " << i << "/" << operandCount; const Operand& operand = operands[i]; NN_VALIDATE_EQ(operand.lifetime, lifetime) << "Model " << lifetime << " operand " << i << " has lifetime of " << operand.lifetime << " instead of the expected " << lifetime; } std::vector sortedIndexes = indexes; std::sort(sortedIndexes.begin(), sortedIndexes.end()); const auto iter = std::adjacent_find(sortedIndexes.begin(), sortedIndexes.end()); NN_VALIDATE(iter == sortedIndexes.end()) << "Model input or output occurs multiple times: " << *iter; for (size_t i = 0; i < operands.size(); ++i) { if (operands[i].lifetime == lifetime) { const auto containsIndex = [&sortedIndexes](size_t index) { return binary_search(sortedIndexes.begin(), sortedIndexes.end(), index); }; NN_VALIDATE(containsIndex(i)) << "Operand " << i << " marked as " << lifetime << " but is not included in Model input or output indexes"; } } return {}; } Result validateExecutionOrder(const Model::Subgraph& subgraph) { // Either the operand has a known value before model execution begins, or we've seen a writer // for this operand while walking operands in execution order. Initialize to known operands. std::vector operandValueKnown; operandValueKnown.reserve(subgraph.operands.size()); std::transform(subgraph.operands.begin(), subgraph.operands.end(), std::back_inserter(operandValueKnown), [](const Operand& operand) { return operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE && operand.lifetime != Operand::LifeTime::SUBGRAPH_OUTPUT; }); // Validate that operations are sorted into execution order. // // If there is a cycle in the graph, the operations will not // appear to be sorted into execution order: Some operation will // have an input for which operandValueKnown[] is false. for (size_t i = 0; i < subgraph.operations.size(); ++i) { const auto& operation = subgraph.operations[i]; for (size_t j = 0; j < operation.inputs.size(); ++j) { const uint32_t k = operation.inputs[j]; NN_VALIDATE(operandValueKnown[k]) << "Operation " << i << " input " << j << " (operand " << k << ") is read before it is written"; } for (size_t j = 0; j < operation.outputs.size(); ++j) { const uint32_t k = operation.outputs[j]; // Assuming validateOperations() has not returned an error, we know that this output is // TEMPORARY_VARIABLE or MODEL_OUTPUT, and so the only way operandValueKnown[k] can be // true is if we've already seen a writer for this operand. NN_VALIDATE(!operandValueKnown[k]) << "Operation " << i << " output " << j << " (operand " << k << ") has already been written"; operandValueKnown[k] = true; } } // Verify all operands are written. for (size_t i = 0; i < subgraph.operands.size(); ++i) { NN_VALIDATE(operandValueKnown[i]) << "Operand " << i << " is never written"; } // TODO(b/77871786): verify that every operation has at least one output operand that is read? return {}; } // Validate a subgraph, ensuring all subgraphs it depends on are also validated. // // `referencedIndex` is empty if the subgraph being validated is the main subgraph, otherwise it is // the index of the referenced subgraph being validated. // // referenced[i] and (*subgraphVersionCache)[i] correspond to the same subgraph, and therefore // `referenced` and `subgraphVersionCache` must have the same length. Result validateModelSubgraph(const Model::Subgraph& subgraph, std::optional referencedIndex, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& referenced, std::vector>* subgraphVersionCache) { CHECK(subgraphVersionCache != nullptr); CHECK_EQ(referenced.size(), subgraphVersionCache->size()); // Quickly return if the current subgraph has already been checked for its version. if (referencedIndex.has_value()) { if (auto version = subgraphVersionCache->at(*referencedIndex)) { return *version; } } NN_VALIDATE(!subgraph.operands.empty()); NN_VALIDATE(!subgraph.operations.empty()); // TODO(b/173780642): Clarify whether subgraphs with no inputs or outputs are valid. // NN_VALIDATE(!subgraph.inputIndexes.empty()); // NN_VALIDATE(!subgraph.outputIndexes.empty()); const auto operandVersions = NN_TRY(validateOperands( subgraph.operands, operandValuesSize, poolSizes, referenced, subgraphVersionCache)); const auto operationsVersion = NN_TRY(validateOperations(subgraph.operations, subgraph.operands, operandVersions, referenced)); // Accumulate the versions from all operands and operations. const auto version = std::accumulate(operandVersions.begin(), operandVersions.end(), operationsVersion, combineVersions); NN_TRY(validateModelSubgraphInputOutputs(subgraph.inputIndexes, subgraph.operands, Operand::LifeTime::SUBGRAPH_INPUT)); NN_TRY(validateModelSubgraphInputOutputs(subgraph.outputIndexes, subgraph.operands, Operand::LifeTime::SUBGRAPH_OUTPUT)); NN_TRY(validateExecutionOrder(subgraph)); // Mark the current subgraph as having already been validated so the caller can quickly return // if this subgraph is checked again. if (referencedIndex.has_value()) { subgraphVersionCache->at(*referencedIndex) = version; } return version; } Result validateModelExtensionNamesAndPrefixes( const std::vector& extensionNamesAndPrefixes) { for (const auto& extensionNameAndPrefix : extensionNamesAndPrefixes) { NN_VALIDATE(isValidExtensionName(extensionNameAndPrefix.name)); } std::vector> names; names.reserve(extensionNamesAndPrefixes.size()); std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(), std::back_inserter(names), [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) { return std::cref(extensionNameAndPrefix.name); }); std::sort(names.begin(), names.end(), std::less{}); const auto nameIter = std::adjacent_find(names.begin(), names.end(), std::equal_to{}); NN_VALIDATE(nameIter == names.end()) << "ExtensionNamesAndPrefixes has duplicate name " << nameIter->get(); std::vector types; types.reserve(extensionNamesAndPrefixes.size()); std::transform(extensionNamesAndPrefixes.begin(), extensionNamesAndPrefixes.end(), std::back_inserter(types), [](const Model::ExtensionNameAndPrefix& extensionNameAndPrefix) { return extensionNameAndPrefix.prefix; }); std::sort(types.begin(), types.end()); const auto typeIter = std::adjacent_find(types.begin(), types.end()); NN_VALIDATE(typeIter == types.end()) << "ExtensionNamesAndPrefixes has duplicate type " << *typeIter; const bool hasExtensions = !extensionNamesAndPrefixes.empty(); return hasExtensions ? Version::ANDROID_Q : Version::ANDROID_OC_MR1; } // Makes sure the model does not contain subgraph reference cycles. // // This function verifies that referencedSubgraphs[subgraphIndex] and any subgraphs it refences do // not contain any reference cycles. `path` is used to keep track of which referenced subgraphs have // already been visited in the current recursive reference path. `verified` is a cache to keep track // of which referenced subgraphs have already been verified not to form reference cycles. // // referencedSubgraphs[i], (*path)[i], and (*verified)[i] all correspond to the same subgraph, and // therefore `referencedSubgraphs`, `path`, and `verified` must all have the same length. Result checkNoReferenceCycles(const std::vector& referencedSubgraphs, uint32_t subgraphIndex, std::vector* path, std::vector* verified) { CHECK(path != nullptr); CHECK(verified != nullptr); CHECK_EQ(referencedSubgraphs.size(), path->size()); CHECK_EQ(referencedSubgraphs.size(), verified->size()); const auto& subgraph = referencedSubgraphs.at(subgraphIndex); // Quickly return if the current subgraph has already been verified to have no reference cycles. if ((*verified)[subgraphIndex]) { return {}; } // Add the current subgraph to the path (making sure that it is not already part of the path), // and verify that all subgraphs this subgraph references do not contain cycles. The current // subgraph is removed from the path only after all subgraphs this subgraph references have been // checked. NN_VALIDATE((*path)[subgraphIndex] == false) << "Model contains a circular subgraph reference"; (*path)[subgraphIndex] = true; for (const Operand& operand : subgraph.operands) { if (operand.lifetime == Operand::LifeTime::SUBGRAPH) { const uint32_t refSubgraphIndex = operand.location.offset; NN_TRY(checkNoReferenceCycles(referencedSubgraphs, refSubgraphIndex, path, verified)); } } (*path)[subgraphIndex] = false; // Mark the current subgraph as having already been verified so the caller can quickly return if // this subgraph is checked again. (*verified)[subgraphIndex] = true; return {}; } Result checkNoReferenceCycles(const std::vector& referencedSubgraphs) { const size_t count = referencedSubgraphs.size(); std::vector path(count); std::vector verified(count); for (size_t i = 0; i < count; ++i) { NN_TRY(checkNoReferenceCycles(referencedSubgraphs, i, &path, &verified)); } return {}; } Result validateModel(const Model& model) { auto version = NN_TRY(validateVector(model.pools, validateSharedMemory)); version = combineVersions( version, NN_TRY(validateModelExtensionNamesAndPrefixes(model.extensionNameToPrefix))); // Ignore relaxComputationFloat32toFloat16 version because in the worst case it makes the // execution stricter. // Referenced models were introduced in Android R. const bool hasReferencedModels = !model.referenced.empty(); const auto referenceModelVersion = hasReferencedModels ? Version::ANDROID_R : Version::ANDROID_OC_MR1; version = combineVersions(version, referenceModelVersion); // Ensure that there are no cycles formed by the subgraphs. NN_TRY(checkNoReferenceCycles(model.referenced)); // Get memory sizes. const auto [operandValuesSize, poolSizes] = getMemorySizes(model); // Validate referenced subgraphs. auto subgraphVersionCache = std::vector>(model.referenced.size()); for (size_t referencedIndex = 0; referencedIndex < model.referenced.size(); ++referencedIndex) { const auto& subgraph = model.referenced[referencedIndex]; const auto subgraphVersion = NN_TRY(validateModelSubgraph(subgraph, referencedIndex, operandValuesSize, poolSizes, model.referenced, &subgraphVersionCache)); version = combineVersions(version, subgraphVersion); } // Validate main subgraph. const auto subgraphVersion = NN_TRY(validateModelSubgraph(model.main, std::nullopt, operandValuesSize, poolSizes, model.referenced, &subgraphVersionCache)); version = combineVersions(version, subgraphVersion); return version; } Result validateBufferDesc(const BufferDesc& bufferDesc) { // An empty BufferDesc is the default value, so it is implicitly valid for all versions. return bufferDesc.dimensions.empty() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R; } Result validateBufferRole(const BufferRole& bufferRole) { NN_VALIDATE_GT(bufferRole.probability, 0.0f); NN_VALIDATE_LE(bufferRole.probability, 1.0f); return Version::ANDROID_R; } Result validateRequestArgument(const Request::Argument& requestArgument, const std::vector& memorySizes, bool isOutput) { const auto lifetime = requestArgument.lifetime; const auto& location = requestArgument.location; const auto& dimensions = requestArgument.dimensions; switch (lifetime) { case Request::Argument::LifeTime::POOL: { NN_VALIDATE(location.pointer == kNullptrVariant); NN_VALIDATE_LT(location.poolIndex, memorySizes.size()); // Do the addition using uint64_t to avoid potential wrap-around problems. const auto lastPosition = static_cast(location.offset) + location.length + location.padding; const auto memorySize = memorySizes[location.poolIndex]; NN_VALIDATE_LE(lastPosition, memorySize); if (memorySize > 0) { // Must specify a positive length if the memory pool has a known size. NN_VALIDATE_GT(location.length, 0u); } return Version::ANDROID_OC_MR1; } case Request::Argument::LifeTime::NO_VALUE: NN_VALIDATE(location.pointer == kNullptrVariant); NN_VALIDATE_EQ(location.poolIndex, 0u); NN_VALIDATE_EQ(location.offset, 0u); NN_VALIDATE_EQ(location.length, 0u); NN_VALIDATE_EQ(location.padding, 0u); NN_VALIDATE(dimensions.empty()); return Version::ANDROID_OC_MR1; case Request::Argument::LifeTime::POINTER: { const bool isNullptr = std::visit([](auto ptr) { return ptr == nullptr; }, location.pointer); NN_VALIDATE(!isNullptr); NN_VALIDATE_EQ(location.poolIndex, 0u); NN_VALIDATE_EQ(location.offset, 0u); NN_VALIDATE_NE(location.length, 0u); if (isOutput) { NN_VALIDATE(std::holds_alternative(location.pointer)); } return Version::ANDROID_OC_MR1; } } NN_VALIDATE_FAIL() << "Invalid Request::Argument::LifeTime " << lifetime; } Result validateRequestMemoryPool(const Request::MemoryPool& memoryPool) { if (std::holds_alternative(memoryPool)) { NN_VALIDATE(std::get(memoryPool) != kInvalidMemoryDomainToken); return Version::ANDROID_R; } if (std::holds_alternative(memoryPool)) { NN_VALIDATE(std::get(memoryPool) != nullptr); return Version::ANDROID_R; } return validateSharedMemory(std::get(memoryPool)); } Result validateRequest(const Request& request) { auto version = NN_TRY(validateVector(request.pools, validateRequestMemoryPool)); // Get memory sizes. For IBuffer or MemoryDomainToken types, set size to 0. std::vector memorySizes; memorySizes.reserve(request.pools.size()); std::transform(request.pools.begin(), request.pools.end(), std::back_inserter(memorySizes), [](const Request::MemoryPool& memoryPool) { const auto* memory = std::get_if(&memoryPool); return memory != nullptr ? getSize(*memory) : 0; }); for (size_t i = 0; i < request.inputs.size(); ++i) { const auto& input = request.inputs[i]; auto result = validateRequestArgument(input, memorySizes, /*isOutput=*/false); if (!result.has_value()) { return error() << std::move(result).error() << " for input RequestArgument " << i; } version = combineVersions(version, result.value()); } for (size_t i = 0; i < request.outputs.size(); ++i) { const auto& output = request.outputs[i]; auto result = validateRequestArgument(output, memorySizes, /*isOutput=*/true); if (!result.has_value()) { return error() << std::move(result).error() << " for output RequestArgument " << i; } version = combineVersions(version, result.value()); } return version; } Result validateOptionalTimePoint(const OptionalTimePoint& optionalTimePoint) { if (optionalTimePoint.has_value()) { NN_VALIDATE_GE(optionalTimePoint->time_since_epoch().count(), 0); } // An omitted time point is the default value, so it is implicitly valid for all versions. return !optionalTimePoint.has_value() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R; } Result validateOptionalTimeoutDuration(const OptionalDuration& optionalTimeoutDuration) { if (optionalTimeoutDuration.has_value()) { NN_VALIDATE_GE(optionalTimeoutDuration->count(), 0); } // An omitted duration is the default value, so it is implicitly valid for all versions. return !optionalTimeoutDuration.has_value() ? Version::ANDROID_OC_MR1 : Version::ANDROID_R; } Result validateCacheToken(const CacheToken& cacheToken) { // A CacheToken of 0 is the default value, so it is implicitly valid for all versions. constexpr auto kDefaultCacheToken = CacheToken{}; return cacheToken == kDefaultCacheToken ? Version::ANDROID_OC_MR1 : Version::ANDROID_Q; } Result validateSyncFence(const SyncFence& syncFence) { // The absence of a sync fence is implicitly valid for all versions. if (!syncFence.hasFd()) { return Version::ANDROID_OC_MR1; } NN_VALIDATE_GE(syncFence.getFd(), 0); return Version::ANDROID_R; } Result validateRequestArgumentsForModel( const std::vector& requestArguments, const std::vector& operandIndexes, const std::vector& operands, bool isOutput, bool allowUnspecifiedOutput) { auto version = Version::ANDROID_OC_MR1; // The request should specify as many arguments as were described in the model. const std::string_view type = isOutput ? "output" : "input"; const size_t requestArgumentCount = requestArguments.size(); NN_VALIDATE_EQ(requestArgumentCount, operandIndexes.size()) << "Request specifies " << requestArgumentCount << " " << type << "s but the model has " << operandIndexes.size(); for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount; requestArgumentIndex++) { const Request::Argument& requestArgument = requestArguments[requestArgumentIndex]; // Get the operand index for this argument. We extract it from the list // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs. // We assume in this function that the model has been validated already. const uint32_t operandIndex = operandIndexes[requestArgumentIndex]; const Operand& operand = operands[operandIndex]; if (requestArgument.lifetime != Request::Argument::LifeTime::NO_VALUE) { const bool isExtensionType = isExtension(operand.type); // If the argument specified a dimension, validate it. uint32_t modelRank = operand.dimensions.size(); uint32_t requestRank = requestArgument.dimensions.size(); if (requestRank == 0) { // NOTE: validateRequestArguments cannot validate unknown tensor rank with // extension operand type. if (!isExtensionType && !isNonExtensionScalar(operand.type)) { if (modelRank <= 0) { NN_VALIDATE(isOutput) << "Model has unknown input rank but the request does not " "specify the rank."; NN_VALIDATE(allowUnspecifiedOutput) << "Model has unknown output rank and request does not specify it."; // Unspecified output dimensions introduced in Android Q. version = combineVersions(version, Version::ANDROID_Q); } } // Validate that all the dimensions are specified in the model. for (size_t i = 0; i < modelRank; i++) { if (operand.dimensions[i] == 0) { NN_VALIDATE(isOutput && allowUnspecifiedOutput) << "Model has dimension " << i << " set to 0 but the request does not specify the dimension."; // Unspecified output dimensions introduced in Android Q. version = combineVersions(version, Version::ANDROID_Q); } } } else { NN_VALIDATE(modelRank == 0 || requestRank == modelRank) << "Request " << type << " " << requestArgumentIndex << " has number of dimensions (" << requestRank << ") different than the model's (" << modelRank << ")"; for (size_t i = 0; i < requestRank; i++) { NN_VALIDATE(modelRank == 0 || operand.dimensions[i] == 0 || requestArgument.dimensions[i] == operand.dimensions[i]) << "Request " << type << " " << requestArgumentIndex << " has dimension " << i << " of " << requestArgument.dimensions[i] << " different than the model's " << operand.dimensions[i]; if (requestArgument.dimensions[i] == 0) { NN_VALIDATE(isOutput && allowUnspecifiedOutput) << "Request " << type << " " << requestArgumentIndex << " has dimension " << i << " of zero"; // Unspecified output dimensions introduced in Android Q. version = combineVersions(version, Version::ANDROID_Q); } } } // NOTE: validateRequestArguments cannot validate DataLocation::length // with extension operand type. if (!isExtensionType && requestArgument.location.length != 0) { const auto dimensions = NN_TRY(combineDimensions(operand.dimensions, requestArgument.dimensions)); const size_t expectedLength = getNonExtensionSize(operand.type, dimensions).value(); if (expectedLength != 0) { NN_VALIDATE_EQ(requestArgument.location.length, expectedLength) << "Request " << type << " " << requestArgumentIndex << " expected a size of " << expectedLength << " but got " << requestArgument.location.length; } } } } return version; } Result validateRequestForModelImpl(const Request& request, const Model& model, bool allowUnspecifiedOutput) { auto version = NN_TRY(validateRequest(request)); version = combineVersions(version, NN_TRY(validateModel(model))); version = combineVersions(version, NN_TRY(validateRequestArgumentsForModel( request.inputs, model.main.inputIndexes, model.main.operands, /*isOutput=*/false, /*allowUnspecifiedOutput=*/true))); version = combineVersions( version, NN_TRY(validateRequestArgumentsForModel( request.outputs, model.main.outputIndexes, model.main.operands, /*isOutput=*/true, allowUnspecifiedOutput))); return version; } Result validateMemoryDescImpl( const BufferDesc& desc, const std::vector& preparedModels, const std::vector& inputRoles, const std::vector& outputRoles, const std::function& getModel, std::set* preparedModelRoles, Operand* combinedOperand) { NN_VALIDATE(!preparedModels.empty()); NN_VALIDATE(!inputRoles.empty() || !outputRoles.empty()); std::set roles; std::vector operands; operands.reserve(inputRoles.size() + outputRoles.size()); for (const auto& role : inputRoles) { NN_VALIDATE_LT(role.modelIndex, preparedModels.size()); const auto& preparedModel = preparedModels[role.modelIndex]; NN_VALIDATE(preparedModel != nullptr); const auto* model = getModel(preparedModel); NN_VALIDATE(model != nullptr); const auto& inputIndexes = model->main.inputIndexes; NN_VALIDATE_LT(role.ioIndex, inputIndexes.size()); NN_VALIDATE_GT(role.probability, 0.0f); NN_VALIDATE_LE(role.probability, 1.0f); const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex); NN_VALIDATE(success); operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]); } for (const auto& role : outputRoles) { NN_VALIDATE_LT(role.modelIndex, preparedModels.size()); const auto& preparedModel = preparedModels[role.modelIndex]; NN_VALIDATE(preparedModel != nullptr); const auto* model = getModel(preparedModel); NN_VALIDATE(model != nullptr); const auto& outputIndexes = model->main.outputIndexes; NN_VALIDATE_LT(role.ioIndex, outputIndexes.size()); NN_VALIDATE_GT(role.probability, 0.0f); NN_VALIDATE_LE(role.probability, 1.0f); const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex); NN_VALIDATE(success); operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]); } CHECK(!operands.empty()); const auto opType = operands.front().type; Dimensions dimensions = desc.dimensions; for (const auto& operand : operands) { NN_VALIDATE_EQ(operand.type, opType) << operand.type << " vs " << operands.front().type; NN_VALIDATE_EQ(operand.scale, operands.front().scale); NN_VALIDATE_EQ(operand.zeroPoint, operands.front().zeroPoint); // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type. if (!isExtension(opType)) { NN_VALIDATE_EQ(operand.extraParams, operands.front().extraParams) << operand.extraParams << " vs " << operands.front().extraParams; } dimensions = NN_TRY(combineDimensions(dimensions, operand.dimensions)); } // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type. if (!isExtension(opType)) { NN_VALIDATE(!isNonExtensionScalar(opType) || dimensions.empty()) << "invalid dimensions with scalar operand type."; } if (preparedModelRoles != nullptr) { *preparedModelRoles = std::move(roles); } if (combinedOperand != nullptr) { *combinedOperand = operands.front(); combinedOperand->dimensions = dimensions; } return Version::ANDROID_R; } class OperationValidationContext : public IOperationValidationContext { DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext); public: OperationValidationContext(const char* operationName, const std::vector& inputIndexes, const std::vector& outputIndexes, const std::vector& operands) : operationName(operationName), inputIndexes(inputIndexes), outputIndexes(outputIndexes), operands(operands) {} const char* getOperationName() const override; uint32_t getNumInputs() const override; OperandType getInputType(uint32_t index) const override; Shape getInputShape(uint32_t index) const override; const Operand::ExtraParams& getInputExtraParams(uint32_t index) const override; uint32_t getNumOutputs() const override; OperandType getOutputType(uint32_t index) const override; Shape getOutputShape(uint32_t index) const override; private: const Operand* getInputOperand(uint32_t index) const; const Operand* getOutputOperand(uint32_t index) const; const char* operationName; const std::vector& inputIndexes; const std::vector& outputIndexes; const std::vector& operands; }; const char* OperationValidationContext::getOperationName() const { return operationName; } const Operand* OperationValidationContext::getInputOperand(uint32_t index) const { return &operands.at(inputIndexes.at(index)); } const Operand* OperationValidationContext::getOutputOperand(uint32_t index) const { return &operands.at(outputIndexes.at(index)); } uint32_t OperationValidationContext::getNumInputs() const { auto count = inputIndexes.size(); CHECK_LE(count, std::numeric_limits::max()); return static_cast(count); } uint32_t OperationValidationContext::getNumOutputs() const { auto count = outputIndexes.size(); CHECK_LE(count, std::numeric_limits::max()); return static_cast(count); } OperandType OperationValidationContext::getInputType(uint32_t index) const { return getInputOperand(index)->type; } Shape OperationValidationContext::getInputShape(uint32_t index) const { const Operand* operand = getInputOperand(index); return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint, operand->extraParams}; } const Operand::ExtraParams& OperationValidationContext::getInputExtraParams(uint32_t index) const { return getInputOperand(index)->extraParams; } OperandType OperationValidationContext::getOutputType(uint32_t index) const { return getOutputOperand(index)->type; } Shape OperationValidationContext::getOutputShape(uint32_t index) const { const Operand* operand = getOutputOperand(index); return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint, operand->extraParams}; } // TODO(b/169345292): reduce the duplicate validation here Result validateOperandSymmPerChannelQuantParamsImpl( const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant, const char* tag) { if (operand.type != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) { NN_VALIDATE_FAIL(); } NN_VALIDATE_LT(channelQuant.channelDim, operand.dimensions.size()) << tag; NN_VALIDATE(!channelQuant.scales.empty()) << tag; NN_VALIDATE_EQ(channelQuant.scales.size(), operand.dimensions[channelQuant.channelDim]) << tag; NN_VALIDATE_NE(operand.dimensions[channelQuant.channelDim], 0u) << tag << " channel dimension " << channelQuant.channelDim << " is underspecified"; for (uint32_t i = 0; i < operand.dimensions[channelQuant.channelDim]; i++) { NN_VALIDATE_GT(channelQuant.scales[i], 0.0f) << tag << " invalid scaleArray[" << i << "]"; } return {}; } Result validateScalarDimensions(const Operand& type, const char* tag) { NN_VALIDATE(type.dimensions.empty()) << tag << " invalid dimensions for scalar type"; return {}; } Result validateQuant8AsymmParams(const Operand& type, const char* tag) { NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 255) << tag << " invalid zeroPoint: " << type.zeroPoint; NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; return {}; } Result validateQuant8AsymmSignedParams(const Operand& type, const char* tag) { NN_VALIDATE(-128 <= type.zeroPoint && type.zeroPoint <= 127) << tag << " invalid zeroPoint: " << type.zeroPoint; NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; return {}; } Result validateQuant8SymmParams(const Operand& type, const char* tag) { NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " invalid zeroPoint: " << type.zeroPoint; NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; return {}; } Result validateQuant16AsymmParams(const Operand& type, const char* tag) { NN_VALIDATE(0 <= type.zeroPoint && type.zeroPoint <= 65535) << tag << " invalid zeroPoint: " << type.zeroPoint; NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; return {}; } Result validateQuantSymmParams(const Operand& type, const char* tag) { NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero"; NN_VALIDATE_GT(type.scale, 0.0f) << tag << " invalid scale"; return {}; } Result validateNoQuantParams(const Operand& type, const char* tag) { NN_VALIDATE_EQ(type.zeroPoint, 0) << tag << " zeroPoint is not zero"; NN_VALIDATE_EQ(type.scale, 0.0f) << tag << " scale is not zero"; return {}; } Result validateTensorDimensions( const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo, const char* tag, bool allowPartial) { if (!allowPartial) { NN_VALIDATE(!type.dimensions.empty()) << tag << " invalid operand dimensions"; } uint64_t size = isExtension(type.type) ? extensionOperandTypeInfo->byteSize : getNonExtensionSize(type.type); constexpr uint64_t kMaxSize = std::numeric_limits::max(); for (size_t i = 0; i < type.dimensions.size(); i++) { if (!allowPartial) { NN_VALIDATE_NE(type.dimensions[i], 0u) << tag << " invalid operand dimensions"; } if (type.dimensions[i] != 0) { size *= type.dimensions[i]; NN_VALIDATE_LE(size, kMaxSize) << tag << " operand byte size exceeds " << kMaxSize; } } return {}; } Result validateOperandTypeImpl( const Operand& type, const Extension::OperandTypeInformation* const extensionOperandTypeInfo, const char* tag, bool allowPartial) { if (isExtension(type.type)) { NN_VALIDATE(extensionOperandTypeInfo != nullptr); if (extensionOperandTypeInfo->isTensor) { NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial)); } else { NN_TRY(validateScalarDimensions(type, tag)); } return validateNoQuantParams(type, tag); } NN_VALIDATE(extensionOperandTypeInfo == nullptr); NN_TRY(validateOperandType(type.type)); if (isNonExtensionScalar(type.type)) { NN_TRY(validateScalarDimensions(type, tag)); if (type.type != OperandType::OEM) { // Historically, we have allowed OEM types // to use quantization parameters. NN_TRY(validateNoQuantParams(type, tag)); } } else { NN_TRY(validateTensorDimensions(type, extensionOperandTypeInfo, tag, allowPartial)); if (type.type == OperandType::TENSOR_QUANT8_ASYMM) { NN_TRY(validateQuant8AsymmParams(type, tag)); } else if (type.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { NN_TRY(validateQuant8AsymmSignedParams(type, tag)); } else if (type.type == OperandType::TENSOR_QUANT8_SYMM) { NN_TRY(validateQuant8SymmParams(type, tag)); } else if (type.type == OperandType::TENSOR_QUANT16_ASYMM) { NN_TRY(validateQuant16AsymmParams(type, tag)); } else if (type.type == OperandType::TENSOR_QUANT16_SYMM) { NN_TRY(validateQuantSymmParams(type, tag)); } else if (type.type == OperandType::TENSOR_INT32 || type.type == OperandType::TENSOR_OEM_BYTE) { // TODO(b/119869082): TENSOR_INT32 should not use quantization parameters. // Historically, we have allowed OEM types to use quantization parameters. } else { NN_TRY(validateNoQuantParams(type, tag)); } } return {}; } Result validateOperandListImpl(const std::vector& list, size_t operandCount, const char* tag) { for (size_t i = 0; i < list.size(); i++) { NN_VALIDATE_LT(list[i], operandCount) << tag << " invalid operand index at " << i << " = " << list[i] << ", operandCount " << operandCount; } return {}; } Result validateOperationOperandTypes(const std::vector& operands, const std::vector& inputIndexes, const std::vector& inExpectedTypes, const std::vector& outputIndexes, const std::vector& outExpectedInTypes) { NN_VALIDATE_EQ(inputIndexes.size(), inExpectedTypes.size()) << "Wrong operand count: expected " << inputIndexes.size() << " inputs, got " << inputIndexes.size() << " inputs"; NN_VALIDATE_EQ(outputIndexes.size(), outExpectedInTypes.size()) << "Wrong operand count: expected " << outputIndexes.size() << " outputs, got " << outputIndexes.size() << " outputs"; for (size_t i = 0; i < inputIndexes.size(); i++) { NN_VALIDATE_EQ(operands[inputIndexes[i]].type, inExpectedTypes[i]) << "Invalid input tensor type " << operands[inputIndexes[i]].type << " for input " << i << ", expected " << inExpectedTypes[i]; } for (size_t i = 0; i < outputIndexes.size(); i++) { NN_VALIDATE_EQ(operands[outputIndexes[i]].type, outExpectedInTypes[i]) << "Invalid output tensor type " << operands[outputIndexes[i]].type << " for input " << i << ", expected " << outExpectedInTypes[i]; } return {}; } Result validateSubgraphReference(const std::vector& subgraphs, const Operand& modelOperand) { NN_VALIDATE_EQ(modelOperand.type, OperandType::SUBGRAPH) << "Unexpected operand type: " << modelOperand.type; NN_VALIDATE_LT(modelOperand.location.offset, subgraphs.size()) << "Invalid subgraph reference"; return {}; } const Model::Subgraph& getSubgraph(const std::vector& subgraphs, const Operand& modelOperand) { return subgraphs.at(modelOperand.location.offset); } uint32_t getInputCount(const std::vector& subgraphs, const Operand& modelOperand) { return getSubgraph(subgraphs, modelOperand).inputIndexes.size(); } uint32_t getOutputCount(const std::vector& subgraphs, const Operand& modelOperand) { return getSubgraph(subgraphs, modelOperand).outputIndexes.size(); } const Operand& getInputOperand(const std::vector& subgraphs, const Operand& modelOperand, uint32_t index) { const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand); return subgraph.operands.at(subgraph.inputIndexes.at(index)); } const Operand& getOutputOperand(const std::vector& subgraphs, const Operand& modelOperand, uint32_t index) { const Model::Subgraph& subgraph = getSubgraph(subgraphs, modelOperand); return subgraph.operands.at(subgraph.outputIndexes.at(index)); } // Checks if two operands have the same types, ranks (if specified), dimensions // (if specified), scales, zeroPoints, and extraParams. Result compatible(const Operand& a, const Operand& b) { NN_VALIDATE_EQ(a.type, b.type) << a.type << " != " << b.type; if (!a.dimensions.empty() && !b.dimensions.empty()) { NN_VALIDATE_EQ(a.dimensions.size(), b.dimensions.size()) << "Incompatible dimensions"; for (uint32_t i = 0, n = a.dimensions.size(); i < n; ++i) { if (a.dimensions[i] != 0 && b.dimensions[i] != 0) { NN_VALIDATE_EQ(a.dimensions[i], b.dimensions[i]) << "Incompatible dimensions"; } } } NN_VALIDATE_EQ(a.scale, b.scale); NN_VALIDATE_EQ(a.zeroPoint, b.zeroPoint); NN_VALIDATE_EQ(a.extraParams, b.extraParams) << a.extraParams << " != " << b.extraParams; return {}; } Result validateConditionOperand(const Operand& operand) { NN_VALIDATE_EQ(operand.type, OperandType::TENSOR_BOOL8) << "Unexpected condition operand type: " << operand.type; NN_VALIDATE_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton"; NN_VALIDATE_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton"; return {}; } Result validateIfOperation(const std::vector& inputs, const std::vector& outputs, const std::vector& operands, const std::vector& subgraphs) { namespace op = operation_if; NN_VALIDATE_GE(inputs.size(), 3u) << "IF must have at least 3 inputs"; NN_VALIDATE_GE(outputs.size(), 1u) << "IF must have at least 1 output"; auto validateBranchOperand = [&](const Operand& branchModelOperand) -> Result { auto result = validateSubgraphReference(subgraphs, branchModelOperand); if (!result.has_value()) { return error() << std::move(result).error() << " -- Operand is not a valid subgraph reference"; } const uint32_t branchModelInputCount = getInputCount(subgraphs, branchModelOperand); const uint32_t branchModelOutputCount = getOutputCount(subgraphs, branchModelOperand); NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + branchModelInputCount); NN_VALIDATE_EQ(outputs.size(), branchModelOutputCount); for (uint32_t i = 0; i < branchModelInputCount; ++i) { const Operand& innerOperand = getInputOperand(subgraphs, branchModelOperand, i); const Operand& outerOperand = operands[inputs[op::kFirstInput + i]]; NN_TRY(compatible(innerOperand, outerOperand)); } for (uint32_t i = 0; i < branchModelOutputCount; ++i) { const Operand& innerOperand = getOutputOperand(subgraphs, branchModelOperand, i); const Operand& outerOperand = operands[outputs[i]]; NN_TRY(compatible(innerOperand, outerOperand)); } return {}; }; auto result = validateConditionOperand(operands[inputs[op::kCondBoolOperand]]); if (!result.has_value()) { return error() << std::move(result).error() << " for IF condition operand"; } result = validateBranchOperand(operands[inputs[op::kThenModelOperand]]); if (!result.has_value()) { return error() << std::move(result).error() << " for IF then model"; } result = validateBranchOperand(operands[inputs[op::kElseModelOperand]]); if (!result.has_value()) { return error() << std::move(result).error() << " for IF else model"; } return Version::ANDROID_R; } Result validateControlFlowOperandUnknownSize(const Operand& operand) { if (!isExtension(operand.type) && getNonExtensionSize(operand).value() == 0) { // 1.3 HAL (corresponding to Version::ANDROID_R) does not support CF operations with // operands of unknown size. See http://b/132458982#comment63. return Version::CURRENT_RUNTIME; } return Version::ANDROID_R; } Result validateWhileOperation(const std::vector& inputs, const std::vector& outputs, const std::vector& operands, const std::vector& subgraphs) { // Let the loop have // - m >= 1 input-output operands, // - k >= 0 state-only operands, and // - n >= 0 input-only operands. // Then // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs. // - the condition model has (m + k + n) inputs and 1 output. // - the body model has (m + k + n) inputs and (m + k) outputs. namespace op = operation_while; NN_VALIDATE_GE(inputs.size(), 3u) << "WHILE must have at least 3 inputs"; NN_VALIDATE_GE(outputs.size(), 1u) << "WHILE must have at least 1 output"; auto validateCondOperand = [&](const Operand& condModelOperand) -> Result { Version version = Version::ANDROID_R; auto result = validateSubgraphReference(subgraphs, condModelOperand); if (!result.has_value()) { return error() << std::move(result).error() << " -- Operand is not a valid subgraph reference"; } const uint32_t condModelInputCount = getInputCount(subgraphs, condModelOperand); const uint32_t condModelOutputCount = getOutputCount(subgraphs, condModelOperand); NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + condModelInputCount); NN_VALIDATE_EQ(condModelOutputCount, 1u); for (uint32_t i = 0; i < condModelInputCount; ++i) { const Operand& innerOperand = getInputOperand(subgraphs, condModelOperand, i); const Operand& outerOperand = operands[inputs[op::kFirstInput + i]]; NN_TRY(compatible(innerOperand, outerOperand)); version = combineVersions(version, NN_TRY(validateControlFlowOperandUnknownSize(innerOperand))); version = combineVersions(version, NN_TRY(validateControlFlowOperandUnknownSize(outerOperand))); } NN_TRY(validateConditionOperand(getOutputOperand(subgraphs, condModelOperand, 0))); return version; }; auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> Result { Version version = Version::ANDROID_R; auto result = validateSubgraphReference(subgraphs, bodyModelOperand); if (!result.has_value()) { return error() << std::move(result).error() << " -- Operand is not a valid subgraph reference"; } const uint32_t bodyModelInputCount = getInputCount(subgraphs, bodyModelOperand); const uint32_t bodyModelOutputCount = getOutputCount(subgraphs, bodyModelOperand); NN_VALIDATE_EQ(inputs.size(), op::kFirstInput + bodyModelInputCount); NN_VALIDATE_GE(bodyModelOutputCount, outputs.size()); NN_VALIDATE_GE(bodyModelInputCount, bodyModelOutputCount); const uint32_t inputOutputCount = outputs.size(); const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount; const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount; for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) { const Operand& innerOperand = getInputOperand(subgraphs, bodyModelOperand, i); const Operand& outerOperand = operands[inputs[op::kFirstInput + i]]; NN_TRY(compatible(innerOperand, outerOperand)); version = combineVersions(version, NN_TRY(validateControlFlowOperandUnknownSize(innerOperand))); version = combineVersions(version, NN_TRY(validateControlFlowOperandUnknownSize(outerOperand))); } for (uint32_t i = 0; i < inputOutputCount; ++i) { const Operand& innerOperand = getOutputOperand(subgraphs, bodyModelOperand, i); const Operand& outerOperand = operands[outputs[i]]; NN_TRY(compatible(innerOperand, outerOperand)); version = combineVersions(version, NN_TRY(validateControlFlowOperandUnknownSize(outerOperand))); } for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) { const Operand& inputOperand = getInputOperand(subgraphs, bodyModelOperand, i); const Operand& outputOperand = getOutputOperand(subgraphs, bodyModelOperand, i); NN_TRY(compatible(inputOperand, outputOperand)); version = combineVersions(version, NN_TRY(validateControlFlowOperandUnknownSize(outputOperand))); } return version; }; auto result = validateCondOperand(operands[inputs[op::kCondModelOperand]]); if (!result.has_value()) { return error() << std::move(result).error() << " for WHILE condition model"; } auto version = result.value(); result = validateBodyOperand(operands[inputs[op::kBodyModelOperand]]); if (!result.has_value()) { return error() << std::move(result).error() << " for WHILE body model"; } version = combineVersions(version, result.value()); return version; } Result validateOperationButNotOperandsImpl(const Operation& operation, const std::vector& operands, const std::vector& subgraphs) { const auto opType = operation.type; const auto& inputIndexes = operation.inputs; const auto& outputIndexes = operation.outputs; NN_TRY(validateOperandListImpl(inputIndexes, operands.size(), "ANeuralNetworksModel_addOperation inputs")); NN_TRY(validateOperandListImpl(outputIndexes, operands.size(), "ANeuralNetworksModel_addOperation outputs")); if (isExtension(opType)) { // There is no other validation we can do for an extension operation. return Version::ANDROID_Q; } auto invalidInOutNumberMessage = [opType, &inputIndexes, &outputIndexes](int expIn, int expOut) { std::ostringstream os; os << "Invalid number of input operands (" << inputIndexes.size() << ", expected " << expIn << ") or output operands (" << outputIndexes.size() << ", expected " << expOut << ") for operation " << opType; return os.str(); }; switch (opType) { case OperationType::OEM_OPERATION: { return Version::ANDROID_OC_MR1; } case OperationType::RESHAPE: { NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) << invalidInOutNumberMessage(2, 1); auto inputType = operands[inputIndexes[0]].type; Version version; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { version = Version::ANDROID_OC_MR1; inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { version = Version::ANDROID_OC_MR1; inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::TENSOR_INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } const auto inputRank = operands[inputIndexes[0]].dimensions.size(); NN_VALIDATE_LE(inputRank, 4u) << "Unsupported input tensor rank for operation " << opType; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::DEPTH_TO_SPACE: { NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && outputIndexes.size() == 1) << "Invalid number of input operands (" << inputIndexes.size() << ", expected 3 or 2) or output operands (" << outputIndexes.size() << ", expected 1) for operation " << opType; auto inputType = operands[inputIndexes[0]].type; Version version; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { version = Version::ANDROID_OC_MR1; inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { version = Version::ANDROID_OC_MR1; inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } if (inputIndexes.size() == 3) { inExpectedTypes.push_back(OperandType::BOOL); version = combineVersions(version, Version::ANDROID_Q); } else { version = combineVersions(version, Version::ANDROID_OC_MR1); } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::SPACE_TO_DEPTH: { NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && outputIndexes.size() == 1) << "Invalid number of input operands (" << inputIndexes.size() << ", expected 3 or 2) or output operands (" << outputIndexes.size() << ", expected 1) for operation " << opType; auto inputType = operands[inputIndexes[0]].type; Version version; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { version = Version::ANDROID_OC_MR1; inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { version = Version::ANDROID_OC_MR1; inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } if (inputIndexes.size() == 3) { inExpectedTypes.push_back(OperandType::BOOL); version = combineVersions(version, Version::ANDROID_Q); } else { version = combineVersions(version, Version::ANDROID_OC_MR1); } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::EMBEDDING_LOOKUP: { NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) << invalidInOutNumberMessage(2, 1); auto inputType = operands[inputIndexes[1]].type; NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) << "Unsupported input tensor type for operation " << opType; Version version; std::vector inExpectedTypes = {OperandType::TENSOR_INT32, inputType}; std::vector outExpectedTypes = {inputType}; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else if (inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { version = Version::ANDROID_Q; } else { version = Version::ANDROID_OC_MR1; } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::HASHTABLE_LOOKUP: { NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 2) << invalidInOutNumberMessage(3, 2); auto inputType = operands[inputIndexes[2]].type; NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) << "Unsupported input tensor type for operation " << opType; std::vector inExpectedTypes = {OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, inputType}; std::vector outExpectedTypes = {inputType, OperandType::TENSOR_QUANT8_ASYMM}; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return Version::ANDROID_OC_MR1; } case OperationType::LSH_PROJECTION: { NN_VALIDATE(inputIndexes.size() == 4 && outputIndexes.size() == 1) << invalidInOutNumberMessage(4, 1); auto inputType = operands[inputIndexes[1]].type; NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) << "Unsupported input tensor type for operation " << opType; auto hashType = operands[inputIndexes[0]].type; Version version; std::vector inExpectedTypes; if (hashType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, inputType, OperandType::TENSOR_FLOAT16, OperandType::INT32, }; } else if (hashType == OperandType::TENSOR_FLOAT32) { version = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_FLOAT32, inputType, OperandType::TENSOR_FLOAT32, OperandType::INT32, }; } else { NN_VALIDATE_FAIL() << "Unsupported hash tensor type for operation " << opType; } std::vector outExpectedTypes = {OperandType::TENSOR_INT32}; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::BIDIRECTIONAL_SEQUENCE_LSTM: { const uint32_t kNumOutputs = 2; const uint32_t kNumOutputsMerged = 1; const uint32_t kNumOutputsWithState = 6; const uint32_t kNumOutputsMergedWithState = 5; NN_VALIDATE(inputIndexes.size() == 61 && (outputIndexes.size() == kNumOutputs || outputIndexes.size() == kNumOutputsMerged || outputIndexes.size() == kNumOutputsWithState || outputIndexes.size() == kNumOutputsMergedWithState)) << "Invalid number of input operands (" << inputIndexes.size() << ", expected 61) or output operands (" << outputIndexes.size() << ", expected 1, 2, 5 or 6) for operation " << opType; std::vector inExpectedTypes; auto inputType = operands[inputIndexes[0]].type; NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_FLOAT16) << "Unsupported input tensor type for operation " << opType; inExpectedTypes = {}; for (int i = 0; i < 48; ++i) { inExpectedTypes.push_back(inputType); } inExpectedTypes.push_back(OperandType::INT32); inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32 ? OperandType::FLOAT32 : OperandType::FLOAT16); inExpectedTypes.push_back(inputType == OperandType::TENSOR_FLOAT32 ? OperandType::FLOAT32 : OperandType::FLOAT16); inExpectedTypes.push_back(OperandType::BOOL); inExpectedTypes.push_back(OperandType::BOOL); for (int i = 0; i < 8; ++i) { inExpectedTypes.push_back(inputType); } Version version = Version::ANDROID_Q; if (outputIndexes.size() == kNumOutputsWithState || outputIndexes.size() == kNumOutputsMergedWithState) { version = Version::ANDROID_R; } std::vector outExpectedTypes(outputIndexes.size(), inputType); NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::LSTM: { NN_VALIDATE((inputIndexes.size() == 23 || inputIndexes.size() == 27) && outputIndexes.size() == 4) << "Invalid number of input operands (" << inputIndexes.size() << ", expected 23 or 27) or output operands (" << outputIndexes.size() << ", expected 4) for operation " << opType; std::vector inExpectedTypes; std::vector outExpectedTypes; auto inputType = operands[inputIndexes[0]].type; NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_FLOAT16) << "Unsupported input tensor type for operation " << opType; Version version = Version::ANDROID_OC_MR1; inExpectedTypes = {inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, inputType, OperandType::INT32}; if (inputType == OperandType::TENSOR_FLOAT32) { inExpectedTypes.push_back(OperandType::FLOAT32); inExpectedTypes.push_back(OperandType::FLOAT32); } else { version = Version::ANDROID_Q; inExpectedTypes.push_back(OperandType::FLOAT16); inExpectedTypes.push_back(OperandType::FLOAT16); } outExpectedTypes = {inputType, inputType, inputType, inputType}; if (inputIndexes.size() == 23) { version = combineVersions(version, Version::ANDROID_OC_MR1); } else { version = combineVersions(version, Version::ANDROID_Q); for (int i = 0; i < 4; ++i) { inExpectedTypes.push_back(inputType); } } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::QUANTIZED_16BIT_LSTM: { NN_VALIDATE(inputIndexes.size() == 15 && outputIndexes.size() == 2) << invalidInOutNumberMessage(15, 2); std::vector inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, OperandType::TENSOR_QUANT16_SYMM, OperandType::TENSOR_QUANT8_ASYMM}; std::vector outExpectedTypes = {OperandType::TENSOR_QUANT16_SYMM, OperandType::TENSOR_QUANT8_ASYMM}; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return Version::ANDROID_Q; } case OperationType::RANDOM_MULTINOMIAL: { NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1) << invalidInOutNumberMessage(3, 1); OperandType inputType = operands[inputIndexes[0]].type; std::vector inExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_FLOAT16) { inExpectedTypes = {inputType, OperandType::INT32, OperandType::TENSOR_INT32}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } std::vector outExpectedTypes = {OperandType::TENSOR_INT32}; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return Version::ANDROID_Q; } case OperationType::RNN: { NN_VALIDATE(inputIndexes.size() == 6 && outputIndexes.size() == 2) << invalidInOutNumberMessage(6, 2); OperandType inputType = operands[inputIndexes[0]].type; Version version; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { version = Version::ANDROID_OC_MR1; inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::INT32, }; outExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, }; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::INT32, }; outExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, }; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::SVDF: { NN_VALIDATE(inputIndexes.size() == 7 && outputIndexes.size() == 2) << invalidInOutNumberMessage(7, 2); Version version; OperandType inputType = operands[inputIndexes[0]].type; if (inputType == OperandType::TENSOR_FLOAT32) { version = Version::ANDROID_OC_MR1; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } std::vector inExpectedTypes = { inputType, inputType, inputType, inputType, inputType, OperandType::INT32, OperandType::INT32, }; std::vector outExpectedTypes = {inputType, inputType}; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::BATCH_TO_SPACE_ND: { NN_VALIDATE((inputIndexes.size() == 3 || inputIndexes.size() == 2) && outputIndexes.size() == 1) << "Invalid number of input operands (" << inputIndexes.size() << ", expected 3 or 2) or output operands (" << outputIndexes.size() << ", expected 1) for operation " << opType; auto inputType = operands[inputIndexes[0]].type; Version version = Version::ANDROID_OC_MR1; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } if (inputIndexes.size() == 3) { inExpectedTypes.push_back(OperandType::BOOL); version = combineVersions(version, Version::ANDROID_Q); } else { version = combineVersions(version, Version::ANDROID_P); } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::SPACE_TO_BATCH_ND: { NN_VALIDATE((inputIndexes.size() == 4 || inputIndexes.size() == 3) && outputIndexes.size() == 1) << "Invalid number of input operands (" << inputIndexes.size() << ", expected 4 or 3) or output operands (" << outputIndexes.size() << ", expected 1) for operation " << opType; auto inputType = operands[inputIndexes[0]].type; Version version = Version::ANDROID_OC_MR1; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) { if (operands[inputIndexes[0]].zeroPoint != 0) { version = Version::ANDROID_Q; } inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; inExpectedTypes = { OperandType::TENSOR_QUANT8_ASYMM_SIGNED, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM_SIGNED}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } if (inputIndexes.size() == 4) { inExpectedTypes.push_back(OperandType::BOOL); version = combineVersions(version, Version::ANDROID_Q); } else { version = combineVersions(version, Version::ANDROID_P); } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::PAD: { NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) << invalidInOutNumberMessage(2, 1); auto inputType = operands[inputIndexes[0]].type; Version version; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { version = Version::ANDROID_P; inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { if (operands[inputIndexes[0]].zeroPoint == 0) { version = Version::ANDROID_P; } else { version = Version::ANDROID_Q; } } inExpectedTypes = { inputType, OperandType::TENSOR_INT32, }; outExpectedTypes = {inputType}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } const auto inputRank = operands[inputIndexes[0]].dimensions.size(); NN_VALIDATE_LE(inputRank, 4u) << "Unsupported input tensor rank for operation " << opType; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::PAD_V2: { NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1) << invalidInOutNumberMessage(3, 1); auto inputType = operands[inputIndexes[0]].type; Version version; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { version = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT32, OperandType::TENSOR_INT32, OperandType::FLOAT32, }; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; inExpectedTypes = { OperandType::TENSOR_FLOAT16, OperandType::TENSOR_INT32, OperandType::FLOAT16, }; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { version = Version::ANDROID_Q; } inExpectedTypes = { inputType, OperandType::TENSOR_INT32, OperandType::INT32, }; // TODO(b/116699425): Make it UINT8. outExpectedTypes = {inputType}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } const auto inputRank = operands[inputIndexes[0]].dimensions.size(); NN_VALIDATE_LE(inputRank, 4u) << "Unsupported input tensor rank for operation " << opType; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::CAST: { NN_VALIDATE(inputIndexes.size() == 1 && outputIndexes.size() == 1) << invalidInOutNumberMessage(1, 1); auto inputOperand = operands[inputIndexes[0]]; auto outputOperand = operands[outputIndexes[0]]; auto inputType = inputOperand.type; auto outputType = outputOperand.type; Version version; std::vector inExpectedTypes; std::vector outExpectedTypes; if ((inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) && (outputType == OperandType::TENSOR_FLOAT16 || outputType == OperandType::TENSOR_FLOAT32 || outputType == OperandType::TENSOR_INT32 || outputType == OperandType::TENSOR_QUANT8_ASYMM)) { version = Version::ANDROID_Q; inExpectedTypes = {inputType}; outExpectedTypes = {outputType}; } else if (inputType == OperandType::TENSOR_BOOL8 || inputType == OperandType::TENSOR_QUANT16_ASYMM || inputType == OperandType::TENSOR_QUANT16_SYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED || inputType == OperandType::TENSOR_QUANT8_SYMM) { version = Version::ANDROID_R; inExpectedTypes = {inputType}; outExpectedTypes = {inputType}; // Only identity CAST is supported. } else { NN_VALIDATE_FAIL() << "Unsupported data type for operation " << opType; } // Validate that output shape is equal to input shape if dimensions // are already known. auto getNumberOfElements = [](const std::vector& dims) { if (dims.empty()) { return 0; } return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>()); }; NN_VALIDATE(inputOperand.dimensions.empty() || outputOperand.dimensions.empty() || getNumberOfElements(outputOperand.dimensions) == 0 || inputOperand.dimensions == outputOperand.dimensions); NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::MEAN: { NN_VALIDATE(inputIndexes.size() == 3 && outputIndexes.size() == 1) << invalidInOutNumberMessage(3, 1); const auto inputRank = operands[inputIndexes[0]].dimensions.size(); NN_VALIDATE_LE(inputRank, 4u) << "Unsupported input tensor rank for operation " << opType; auto inputType = operands[inputIndexes[0]].type; Version version; if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM) { version = Version::ANDROID_P; } else if (inputType == OperandType::TENSOR_FLOAT16) { version = Version::ANDROID_Q; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } std::vector inExpectedTypes = {inputType, OperandType::TENSOR_INT32, OperandType::INT32}; std::vector outExpectedTypes = {inputType}; NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::ARGMAX: case OperationType::ARGMIN: { NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) << invalidInOutNumberMessage(2, 1); auto inputType = operands[inputIndexes[0]].type; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { inExpectedTypes = {inputType, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_INT32}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return Version::ANDROID_Q; } case OperationType::EXPAND_DIMS: { NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) << invalidInOutNumberMessage(2, 1); auto inputType = operands[inputIndexes[0]].type; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { inExpectedTypes = {inputType, OperandType::INT32}; outExpectedTypes = {inputType}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } Version version; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { version = Version::ANDROID_Q; } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::SPLIT: { NN_VALIDATE_EQ(inputIndexes.size(), 3u) << "Invalid number of input operands (" << inputIndexes.size() << ", expected 3)" << opType; auto inputType = operands[inputIndexes[0]].type; NN_VALIDATE(inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) << "Unsupported input tensor type for operation " << opType; Version version; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { version = Version::ANDROID_Q; } std::vector inExpectedTypes = {inputType, OperandType::INT32, OperandType::INT32}; std::vector outExpectedTypes(outputIndexes.size(), inputType); NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::MAXIMUM: case OperationType::MINIMUM: { NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) << invalidInOutNumberMessage(2, 1); std::vector inExpectedTypes; std::vector outExpectedTypes; OperandType inputType = operands[inputIndexes[0]].type; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { inExpectedTypes = {inputType, inputType}; outExpectedTypes = {inputType}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } Version version; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { version = Version::ANDROID_Q; } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::GROUPED_CONV_2D: { NN_VALIDATE((inputIndexes.size() == 12 || inputIndexes.size() == 9) && outputIndexes.size() == 1) << "Invalid number of input operands (" << inputIndexes.size() << ", expected 12 or 9) or output operands (" << outputIndexes.size() << ", expected 1) for operation " << opType; auto inputType = operands[inputIndexes[0]].type; auto filterType = operands[inputIndexes[1]].type; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { NN_VALIDATE(filterType == inputType || filterType == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) << "Unsupported filter tensor type for operation " << opType; NN_VALIDATE(filterType != OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL || std::get( operands[inputIndexes[1]].extraParams) .channelDim == 0) << "Unsupported filter tensor channel dimension for operation " << opType; inExpectedTypes = { inputType, filterType, OperandType::TENSOR_INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32, OperandType::INT32}; outExpectedTypes = {inputType}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } if (inputIndexes.size() == 12) { std::vector explicitScalarTypes(3, OperandType::INT32); inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(), explicitScalarTypes.end()); } inExpectedTypes.push_back(OperandType::BOOL); Version version; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { version = Version::ANDROID_Q; } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::TILE: { NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) << invalidInOutNumberMessage(2, 1); auto inputType = operands[inputIndexes[0]].type; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM || inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { inExpectedTypes = {inputType, OperandType::TENSOR_INT32}; outExpectedTypes = {inputType}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } Version version; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { version = Version::ANDROID_Q; } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::POW: { NN_VALIDATE(inputIndexes.size() == 2 && outputIndexes.size() == 1) << invalidInOutNumberMessage(2, 1); auto inputType = operands[inputIndexes[0]].type; std::vector inExpectedTypes; std::vector outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32) { inExpectedTypes = {inputType, inputType}; outExpectedTypes = {inputType}; } else { NN_VALIDATE_FAIL() << "Unsupported input tensor type for operation " << opType; } Version version; if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) { version = Version::ANDROID_R; } else { version = Version::ANDROID_Q; } NN_TRY(validateOperationOperandTypes(operands, inputIndexes, inExpectedTypes, outputIndexes, outExpectedTypes)); return version; } case OperationType::IF: { return validateIfOperation(inputIndexes, outputIndexes, operands, subgraphs); } case OperationType::WHILE: { return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs); } default: { const OperationRegistration* operationRegistration = BuiltinOperationResolver::get()->findOperation( static_cast(opType)); // TODO: return ErrorStatus::UNEXPECTED_NULL NN_VALIDATE(operationRegistration != nullptr) << opType << " not registered"; // TODO: return ErrorStatus::UNEXPECTED_NULL NN_VALIDATE(operationRegistration->validate != nullptr) << "Incomplete operation registration: " << opType; OperationValidationContext context(operationRegistration->name, inputIndexes, outputIndexes, operands); return operationRegistration->validate(&context); } } } Result validateOperationIncludingOperandVersions( const Operation& operation, const std::vector& operands, const std::vector& operandVersions, const std::vector& subgraphs) { auto version = NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs)); for (uint32_t index : operation.inputs) { version = combineVersions(version, operandVersions[index]); } for (uint32_t index : operation.outputs) { version = combineVersions(version, operandVersions[index]); } return version; } } // anonymous namespace // Below this point are all the functions that are declared in Validation.h. The convention of this // file is to keep the function bodies of the functions declared in Validation.h minimal, meaning // that most functions below simply redirect to one of the functions defined above in the anonymous // namespace. If there is a function name clash between one of the functions below and one of the // functions above, the function in the anonymous namespace is appended with "Impl". Version combineVersions(Version lhs, Version rhs) { return std::max(lhs, rhs); } Result validate(const DeviceStatus& deviceStatus) { return validateDeviceStatus(deviceStatus); } Result validate(const ExecutionPreference& executionPreference) { return validateExecutionPreference(executionPreference); } Result validate(const DeviceType& deviceType) { return validateDeviceType(deviceType); } Result validate(const MeasureTiming& measureTiming) { return validateMeasureTiming(measureTiming); } Result validate(const OperandType& operandType) { return validateOperandType(operandType); } Result validate(const Priority& priority) { return validatePriority(priority); } Result validate(const ErrorStatus& errorStatus) { return validateErrorStatus(errorStatus); } Result validate(const FusedActivationFunc& activation) { return validateFusedActivationFunc(activation); } Result validate(const OutputShape& outputShape) { return validateOutputShape(outputShape); } Result validate(const Timing& timing) { return validateTiming(timing); } Result validate(const Capabilities& capabilities) { return validateCapabilities(capabilities); } Result validate(const Extension& extension) { return validateExtension(extension); } Result validate(const SharedHandle& handle) { return validateSharedHandle(handle); } Result validate(const SharedMemory& memory) { return validateSharedMemory(memory); } Result validate(const Model& model) { return validateModel(model); } Result validate(const BufferDesc& bufferDesc) { return validateBufferDesc(bufferDesc); } Result validate(const BufferRole& bufferRole) { return validateBufferRole(bufferRole); } Result validate(const Request& request) { return validateRequest(request); } Result validate(const OptionalTimePoint& optionalTimePoint) { return validateOptionalTimePoint(optionalTimePoint); } Result validate(const OptionalDuration& optionalTimeoutDuration) { return validateOptionalTimeoutDuration(optionalTimeoutDuration); } Result validate(const CacheToken& cacheToken) { return validateCacheToken(cacheToken); } Result validate(const SyncFence& syncFence) { return validateSyncFence(syncFence); } Result validate(const std::vector& outputShapes) { return validateVector(outputShapes, validateOutputShape); } Result validate(const std::vector& extensions) { return validateExtensions(extensions); } Result validate(const std::vector& handles) { return validateVector(handles, validateSharedHandle); } Result validate(const std::vector& bufferRoles) { return validateVector(bufferRoles, validateBufferRole); } Result validate(const std::vector& syncFences) { return validateVector(syncFences, validateSyncFence); } Result validateRequestForModel(const Request& request, const Model& model, bool allowUnspecifiedOutput) { return validateRequestForModelImpl(request, model, allowUnspecifiedOutput); } Result validateMemoryDesc( const BufferDesc& desc, const std::vector& preparedModels, const std::vector& inputRoles, const std::vector& outputRoles, const std::function& getModel, std::set* preparedModelRoles, Operand* combinedOperand) { return validateMemoryDescImpl(desc, preparedModels, inputRoles, outputRoles, getModel, preparedModelRoles, combinedOperand); } Result validateOperandSymmPerChannelQuantParams( const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant, const char* tag) { return validateOperandSymmPerChannelQuantParamsImpl(operand, channelQuant, tag); } Result validateOperandType(const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo, const char* tag, bool allowPartial) { return validateOperandTypeImpl(type, extensionOperandTypeInfo, tag, allowPartial); } Result validateOperandList(const std::vector& list, size_t operandCount, const char* tag) { return validateOperandListImpl(list, operandCount, tag); } Result validateOperationButNotOperands(const Operation& operation, const std::vector& operands, const std::vector& subgraphs) { NN_TRY(validateOperationButNotOperandsImpl(operation, operands, subgraphs)); return {}; } struct SubgraphVersionCache { std::vector> cache; }; std::unique_ptr createSubgraphVersionCache( size_t referencedSubgraphCount) { auto subgraphVersionCache = std::make_unique(); subgraphVersionCache->cache.resize(referencedSubgraphCount); constexpr auto deleter = [](SubgraphVersionCache* ptr) { delete ptr; }; return {subgraphVersionCache.release(), deleter}; } Result validateOperationAndAnythingItDependsOn( const Operation& operation, const std::vector& operands, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& subgraphs, SubgraphVersionCache* subgraphVersionCache) { CHECK(subgraphVersionCache != nullptr); std::vector operandVersions(operands.size(), Version::ANDROID_OC_MR1); for (uint32_t index : operation.inputs) { NN_VALIDATE_LT(index, operands.size()); const Operand& operand = operands[index]; operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn( operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache)); } for (uint32_t index : operation.outputs) { NN_VALIDATE_LT(index, operands.size()); const Operand& operand = operands[index]; operandVersions[index] = NN_TRY(validateOperandAndAnythingItDependsOn( operand, operandValuesSize, poolSizes, subgraphs, subgraphVersionCache)); } return validateOperationIncludingOperandVersions(operation, operands, operandVersions, subgraphs); } Result validateOperandAndAnythingItDependsOn(const Operand& operand, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& subgraphs, SubgraphVersionCache* subgraphVersionCache) { CHECK(subgraphVersionCache != nullptr); return validateOperand(operand, operandValuesSize, poolSizes, subgraphs, &subgraphVersionCache->cache); } } // namespace android::nn