1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H 18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H 19 20 #include <utility> 21 #include <variant> 22 #include <vector> 23 24 #include "NeuralNetworksExtensions.h" 25 #include "NeuralNetworksWrapper.h" 26 27 namespace android { 28 namespace nn { 29 namespace extension_wrapper { 30 31 using wrapper::SymmPerChannelQuantParams; 32 using wrapper::Type; 33 34 struct ExtensionOperandParams { 35 std::vector<uint8_t> data; 36 ExtensionOperandParamsExtensionOperandParams37 ExtensionOperandParams(std::vector<uint8_t> data) : data(std::move(data)) {} 38 39 template <typename T> ExtensionOperandParamsExtensionOperandParams40 ExtensionOperandParams(const T& data) 41 : ExtensionOperandParams( 42 std::vector(reinterpret_cast<const uint8_t*>(&data), 43 reinterpret_cast<const uint8_t*>(&data) + sizeof(data))) { 44 static_assert(std::is_trivially_copyable<T>::value, "data must be trivially copyable"); 45 } 46 }; 47 48 struct OperandType { 49 using ExtraParams = 50 std::variant<std::monostate, SymmPerChannelQuantParams, ExtensionOperandParams>; 51 52 ANeuralNetworksOperandType operandType; 53 std::vector<uint32_t> dimensions; 54 ExtraParams extraParams; 55 OperandTypeOperandType56 OperandType(const OperandType& other) 57 : operandType(other.operandType), 58 dimensions(other.dimensions), 59 extraParams(other.extraParams) { 60 operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr; 61 } 62 63 OperandType& operator=(const OperandType& other) { 64 if (this != &other) { 65 operandType = other.operandType; 66 dimensions = other.dimensions; 67 extraParams = other.extraParams; 68 operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr; 69 } 70 return *this; 71 } 72 73 OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0, 74 ExtraParams&& extraParams = std::monostate()) dimensionsOperandType75 : dimensions(std::move(d)), extraParams(std::move(extraParams)) { 76 operandType = { 77 .type = static_cast<int32_t>(type), 78 .dimensionCount = static_cast<uint32_t>(dimensions.size()), 79 .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr, 80 .scale = scale, 81 .zeroPoint = zeroPoint, 82 }; 83 } 84 OperandTypeOperandType85 OperandType(Type type, std::vector<uint32_t> dimensions, float scale, int32_t zeroPoint, 86 SymmPerChannelQuantParams&& channelQuant) 87 : OperandType(type, dimensions, scale, zeroPoint, ExtraParams(std::move(channelQuant))) {} 88 OperandTypeOperandType89 OperandType(Type type, std::vector<uint32_t> dimensions, ExtraParams&& extraParams) 90 : OperandType(type, dimensions, 0.0f, 0, std::move(extraParams)) {} 91 }; 92 93 class Model : public wrapper::Model { 94 public: 95 using wrapper::Model::Model; // Inherit constructors. 96 getExtensionOperandType(const char * extensionName,uint16_t typeWithinExtension)97 int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) { 98 int32_t result; 99 if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension, 100 &result) != ANEURALNETWORKS_NO_ERROR) { 101 mValid = false; 102 } 103 return result; 104 } 105 getExtensionOperationType(const char * extensionName,uint16_t typeWithinExtension)106 ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName, 107 uint16_t typeWithinExtension) { 108 ANeuralNetworksOperationType result; 109 if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName, 110 typeWithinExtension, 111 &result) != ANEURALNETWORKS_NO_ERROR) { 112 mValid = false; 113 } 114 return result; 115 } 116 addOperand(const OperandType * type)117 uint32_t addOperand(const OperandType* type) { 118 if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != 119 ANEURALNETWORKS_NO_ERROR) { 120 mValid = false; 121 } 122 if (std::holds_alternative<SymmPerChannelQuantParams>(type->extraParams)) { 123 const auto& channelQuant = std::get<SymmPerChannelQuantParams>(type->extraParams); 124 if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams( 125 mModel, mNextOperandId, &channelQuant.params) != ANEURALNETWORKS_NO_ERROR) { 126 mValid = false; 127 } 128 } else if (std::holds_alternative<ExtensionOperandParams>(type->extraParams)) { 129 const auto& extension = std::get<ExtensionOperandParams>(type->extraParams); 130 if (ANeuralNetworksModel_setOperandExtensionData( 131 mModel, mNextOperandId, extension.data.data(), extension.data.size()) != 132 ANEURALNETWORKS_NO_ERROR) { 133 mValid = false; 134 } 135 } 136 return mNextOperandId++; 137 } 138 }; 139 140 } // namespace extension_wrapper 141 142 namespace wrapper { 143 144 using ExtensionModel = extension_wrapper::Model; 145 using ExtensionOperandType = extension_wrapper::OperandType; 146 using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams; 147 148 } // namespace wrapper 149 } // namespace nn 150 } // namespace android 151 152 #endif // ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H 153