/* * Copyright (C) 2020 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H #define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H #include #include #include #include #include "nnapi/Result.h" #include "nnapi/Types.h" namespace android::nn { // Utility functions Version combineVersions(Version lhs, Version rhs); Result validate(const DeviceStatus& deviceStatus); Result validate(const ExecutionPreference& executionPreference); Result validate(const DeviceType& deviceType); Result validate(const MeasureTiming& measureTiming); Result validate(const OperandType& operandType); Result validate(const Priority& priority); Result validate(const ErrorStatus& errorStatus); Result validate(const FusedActivationFunc& activation); Result validate(const OutputShape& outputShape); Result validate(const Timing& timing); Result validate(const Capabilities& capabilities); Result validate(const Extension& extension); Result validate(const SharedHandle& handle); Result validate(const SharedMemory& memory); Result validate(const Model& model); Result validate(const BufferDesc& bufferDesc); Result validate(const BufferRole& bufferRole); Result validate(const Request& request); Result validate(const OptionalTimePoint& optionalTimePoint); Result validate(const OptionalDuration& optionalTimeoutDuration); Result validate(const CacheToken& cacheToken); Result validate(const SyncFence& syncFence); Result validate(const std::vector& outputShapes); Result validate(const std::vector& extensions); Result validate(const std::vector& handles); Result validate(const std::vector& bufferRoles); Result validate(const std::vector& syncFences); // Validate request applied to model. Result validateRequestForModel(const Request& request, const Model& model, bool allowUnspecifiedOutput = true); // Validate memory descriptor. enum class IOType { INPUT, OUTPUT }; using PreparedModelRole = std::tuple; // Verifies that the input arguments to IDevice::allocate are valid. // Optionally, this function can return a flattened prepared model roles and a combined operand. // Pass nullptr if either value is not needed. // IMPORTANT: This function cannot validate dimensions and extraParams with extension operand type. // Each driver should do their own validation of extension type dimensions and extraParams. Result validateMemoryDesc( const BufferDesc& desc, const std::vector& preparedModels, const std::vector& inputRoles, const std::vector& outputRoles, const std::function& getModel, std::set* preparedModelRoles, Operand* combinedOperand); Result validateOperandSymmPerChannelQuantParams( const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant, const char* tag); // Validates an operand type. // // extensionOperandTypeInfo must be nullptr iff the type is not an extension type. // // If allowPartial is true, the dimensions may be underspecified. Result validateOperandType(const Operand& type, const Extension::OperandTypeInformation* extensionOperandTypeInfo, const char* tag, bool allowPartial); Result validateOperandList(const std::vector& list, size_t operandCount, const char* tag); // Validates the operation, and ensures it uses subgraphs in a valid way, but does not validate any // subgraphs or operands themselves. // // This function is currently used by ModelBuilder. Result validateOperationButNotOperands(const Operation& operation, const std::vector& operands, const std::vector& subgraphs); // Forward declaration for a utility class for caching a referenced subgraph's version. struct SubgraphVersionCache; // Function to create an opaque handle to a utility class for caching a referenced subgraph's // version. std::unique_ptr createSubgraphVersionCache( size_t subgraphCount); // Validate the operation or operand, also validating any subgraphs and operands it may use, // recursively. // // `subgraphVersionCache` is used to cache validation information for `subgraphs`, which would // otherwise be unnecessarily re-validated. For this reason, `subgraphVersionCache` must be non-null // and must have been created with the number of referenced subgraphs in `subgraphs`. The provided // subgraphs must not form a reference cycle. // // These functions are currently used by MetaModel. Result validateOperationAndAnythingItDependsOn( const Operation& operation, const std::vector& operands, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& subgraphs, SubgraphVersionCache* subgraphVersionCache); Result validateOperandAndAnythingItDependsOn(const Operand& operand, size_t operandValuesSize, const std::vector& poolSizes, const std::vector& subgraphs, SubgraphVersionCache* subgraphVersionCache); } // namespace android::nn #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H