1 /* 2 * Copyright (C) 2019 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H 18 #define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H 19 20 #include "NeuralNetworksExtensions.h" 21 #include "NeuralNetworksWrapper.h" 22 23 #include <variant> 24 25 namespace android { 26 namespace nn { 27 namespace extension_wrapper { 28 29 using wrapper::SymmPerChannelQuantParams; 30 using wrapper::Type; 31 32 struct ExtensionOperandParams { 33 std::vector<uint8_t> data; 34 ExtensionOperandParamsExtensionOperandParams35 ExtensionOperandParams(std::vector<uint8_t> data) : data(std::move(data)) {} 36 37 template <typename T> ExtensionOperandParamsExtensionOperandParams38 ExtensionOperandParams(const T& data) 39 : ExtensionOperandParams( 40 std::vector(reinterpret_cast<const uint8_t*>(&data), 41 reinterpret_cast<const uint8_t*>(&data) + sizeof(data))) { 42 static_assert(std::is_trivially_copyable<T>::value, "data must be trivially copyable"); 43 } 44 }; 45 46 struct OperandType { 47 using ExtraParams = 48 std::variant<std::monostate, SymmPerChannelQuantParams, ExtensionOperandParams>; 49 50 ANeuralNetworksOperandType operandType; 51 std::vector<uint32_t> dimensions; 52 ExtraParams extraParams; 53 OperandTypeOperandType54 OperandType(const OperandType& other) 55 : operandType(other.operandType), 56 dimensions(other.dimensions), 57 extraParams(other.extraParams) { 58 operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr; 59 } 60 61 OperandType& operator=(const OperandType& other) { 62 if (this != &other) { 63 operandType = other.operandType; 64 dimensions = other.dimensions; 65 extraParams = other.extraParams; 66 operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr; 67 } 68 return *this; 69 } 70 71 OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0, 72 ExtraParams&& extraParams = std::monostate()) dimensionsOperandType73 : dimensions(std::move(d)), extraParams(std::move(extraParams)) { 74 operandType = { 75 .type = static_cast<int32_t>(type), 76 .dimensionCount = static_cast<uint32_t>(dimensions.size()), 77 .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr, 78 .scale = scale, 79 .zeroPoint = zeroPoint, 80 }; 81 } 82 OperandTypeOperandType83 OperandType(Type type, std::vector<uint32_t> dimensions, float scale, int32_t zeroPoint, 84 SymmPerChannelQuantParams&& channelQuant) 85 : OperandType(type, dimensions, scale, zeroPoint, ExtraParams(std::move(channelQuant))) {} 86 OperandTypeOperandType87 OperandType(Type type, std::vector<uint32_t> dimensions, ExtraParams&& extraParams) 88 : OperandType(type, dimensions, 0.0f, 0, std::move(extraParams)) {} 89 }; 90 91 class Model : public wrapper::Model { 92 public: 93 using wrapper::Model::Model; // Inherit constructors. 94 getExtensionOperandType(const char * extensionName,uint16_t typeWithinExtension)95 int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) { 96 int32_t result; 97 if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension, 98 &result) != ANEURALNETWORKS_NO_ERROR) { 99 mValid = false; 100 } 101 return result; 102 } 103 getExtensionOperationType(const char * extensionName,uint16_t typeWithinExtension)104 ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName, 105 uint16_t typeWithinExtension) { 106 ANeuralNetworksOperationType result; 107 if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName, 108 typeWithinExtension, 109 &result) != ANEURALNETWORKS_NO_ERROR) { 110 mValid = false; 111 } 112 return result; 113 } 114 addOperand(const OperandType * type)115 uint32_t addOperand(const OperandType* type) { 116 if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != 117 ANEURALNETWORKS_NO_ERROR) { 118 mValid = false; 119 } 120 if (std::holds_alternative<SymmPerChannelQuantParams>(type->extraParams)) { 121 const auto& channelQuant = std::get<SymmPerChannelQuantParams>(type->extraParams); 122 if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams( 123 mModel, mNextOperandId, &channelQuant.params) != ANEURALNETWORKS_NO_ERROR) { 124 mValid = false; 125 } 126 } else if (std::holds_alternative<ExtensionOperandParams>(type->extraParams)) { 127 const auto& extension = std::get<ExtensionOperandParams>(type->extraParams); 128 if (ANeuralNetworksModel_setOperandExtensionData( 129 mModel, mNextOperandId, extension.data.data(), extension.data.size()) != 130 ANEURALNETWORKS_NO_ERROR) { 131 mValid = false; 132 } 133 } 134 return mNextOperandId++; 135 } 136 }; 137 138 } // namespace extension_wrapper 139 140 namespace wrapper { 141 142 using ExtensionModel = extension_wrapper::Model; 143 using ExtensionOperandType = extension_wrapper::OperandType; 144 using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams; 145 146 } // namespace wrapper 147 } // namespace nn 148 } // namespace android 149 150 #endif // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H 151