1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H 18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H 19 20 #include <memory> 21 #include <set> 22 #include <tuple> 23 #include <vector> 24 25 #include "nnapi/Result.h" 26 #include "nnapi/Types.h" 27 28 namespace android::nn { 29 30 // Utility functions 31 32 Version combineVersions(Version lhs, Version rhs); 33 34 Result<Version> validate(const DeviceStatus& deviceStatus); 35 Result<Version> validate(const ExecutionPreference& executionPreference); 36 Result<Version> validate(const DeviceType& deviceType); 37 Result<Version> validate(const MeasureTiming& measureTiming); 38 Result<Version> validate(const OperandType& operandType); 39 Result<Version> validate(const Priority& priority); 40 Result<Version> validate(const ErrorStatus& errorStatus); 41 Result<Version> validate(const FusedActivationFunc& activation); 42 Result<Version> validate(const OutputShape& outputShape); 43 Result<Version> validate(const Timing& timing); 44 Result<Version> validate(const Capabilities& capabilities); 45 Result<Version> validate(const Extension& extension); 46 Result<Version> validate(const SharedHandle& handle); 47 Result<Version> validate(const SharedMemory& memory); 48 Result<Version> validate(const Model& model); 49 Result<Version> validate(const BufferDesc& bufferDesc); 50 Result<Version> validate(const BufferRole& bufferRole); 51 Result<Version> validate(const Request& request); 52 Result<Version> validate(const OptionalTimePoint& optionalTimePoint); 53 Result<Version> validate(const OptionalDuration& optionalTimeoutDuration); 54 Result<Version> validate(const CacheToken& cacheToken); 55 Result<Version> validate(const SyncFence& syncFence); 56 57 Result<Version> validate(const std::vector<OutputShape>& outputShapes); 58 Result<Version> validate(const std::vector<Extension>& extensions); 59 Result<Version> validate(const std::vector<SharedHandle>& handles); 60 Result<Version> validate(const std::vector<BufferRole>& bufferRoles); 61 Result<Version> validate(const std::vector<SyncFence>& syncFences); 62 63 // Validate request applied to model. 64 Result<Version> validateRequestForModel(const Request& request, const Model& model, 65 bool allowUnspecifiedOutput = true); 66 67 // Validate memory descriptor. 68 enum class IOType { INPUT, OUTPUT }; 69 using PreparedModelRole = std::tuple<const IPreparedModel*, IOType, uint32_t>; 70 71 // Verifies that the input arguments to IDevice::allocate are valid. 72 // Optionally, this function can return a flattened prepared model roles and a combined operand. 73 // Pass nullptr if either value is not needed. 74 // IMPORTANT: This function cannot validate dimensions and extraParams with extension operand type. 75 // Each driver should do their own validation of extension type dimensions and extraParams. 76 Result<Version> validateMemoryDesc( 77 const BufferDesc& desc, const std::vector<SharedPreparedModel>& preparedModels, 78 const std::vector<BufferRole>& inputRoles, const std::vector<BufferRole>& outputRoles, 79 const std::function<const Model*(const SharedPreparedModel&)>& getModel, 80 std::set<PreparedModelRole>* preparedModelRoles, Operand* combinedOperand); 81 82 Result<void> validateOperandSymmPerChannelQuantParams( 83 const Operand& operand, const Operand::SymmPerChannelQuantParams& channelQuant, 84 const char* tag); 85 86 // Validates an operand type. 87 // 88 // extensionOperandTypeInfo must be nullptr iff the type is not an extension type. 89 // 90 // If allowPartial is true, the dimensions may be underspecified. 91 Result<void> validateOperandType(const Operand& type, 92 const Extension::OperandTypeInformation* extensionOperandTypeInfo, 93 const char* tag, bool allowPartial); 94 Result<void> validateOperandList(const std::vector<uint32_t>& list, size_t operandCount, 95 const char* tag); 96 97 // Validates the operation, and ensures it uses subgraphs in a valid way, but does not validate any 98 // subgraphs or operands themselves. 99 // 100 // This function is currently used by ModelBuilder. 101 Result<void> validateOperationButNotOperands(const Operation& operation, 102 const std::vector<Operand>& operands, 103 const std::vector<Model::Subgraph>& subgraphs); 104 105 // Forward declaration for a utility class for caching a referenced subgraph's version. 106 struct SubgraphVersionCache; 107 108 // Function to create an opaque handle to a utility class for caching a referenced subgraph's 109 // version. 110 std::unique_ptr<SubgraphVersionCache, void (*)(SubgraphVersionCache*)> createSubgraphVersionCache( 111 size_t subgraphCount); 112 113 // Validate the operation or operand, also validating any subgraphs and operands it may use, 114 // recursively. 115 // 116 // `subgraphVersionCache` is used to cache validation information for `subgraphs`, which would 117 // otherwise be unnecessarily re-validated. For this reason, `subgraphVersionCache` must be non-null 118 // and must have been created with the number of referenced subgraphs in `subgraphs`. The provided 119 // subgraphs must not form a reference cycle. 120 // 121 // These functions are currently used by MetaModel. 122 Result<Version> validateOperationAndAnythingItDependsOn( 123 const Operation& operation, const std::vector<Operand>& operands, size_t operandValuesSize, 124 const std::vector<size_t>& poolSizes, const std::vector<Model::Subgraph>& subgraphs, 125 SubgraphVersionCache* subgraphVersionCache); 126 Result<Version> validateOperandAndAnythingItDependsOn(const Operand& operand, 127 size_t operandValuesSize, 128 const std::vector<size_t>& poolSizes, 129 const std::vector<Model::Subgraph>& subgraphs, 130 SubgraphVersionCache* subgraphVersionCache); 131 132 } // namespace android::nn 133 134 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_VALIDATION_H 135