/* * Copyright (C) 2019 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H #define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H #include "NeuralNetworksExtensions.h" #include "NeuralNetworksWrapper.h" #include namespace android { namespace nn { namespace extension_wrapper { using wrapper::SymmPerChannelQuantParams; using wrapper::Type; struct ExtensionOperandParams { std::vector data; ExtensionOperandParams(std::vector data) : data(std::move(data)) {} template ExtensionOperandParams(const T& data) : ExtensionOperandParams( std::vector(reinterpret_cast(&data), reinterpret_cast(&data) + sizeof(data))) { static_assert(std::is_trivially_copyable::value, "data must be trivially copyable"); } }; struct OperandType { using ExtraParams = std::variant; ANeuralNetworksOperandType operandType; std::vector dimensions; ExtraParams extraParams; OperandType(const OperandType& other) : operandType(other.operandType), dimensions(other.dimensions), extraParams(other.extraParams) { operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr; } OperandType& operator=(const OperandType& other) { if (this != &other) { operandType = other.operandType; dimensions = other.dimensions; extraParams = other.extraParams; operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr; } return *this; } OperandType(Type type, std::vector d, float scale = 0.0f, int32_t zeroPoint = 0, ExtraParams&& extraParams = std::monostate()) : dimensions(std::move(d)), extraParams(std::move(extraParams)) { operandType = { .type = static_cast(type), .dimensionCount = static_cast(dimensions.size()), .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr, .scale = scale, .zeroPoint = zeroPoint, }; } OperandType(Type type, std::vector dimensions, float scale, int32_t zeroPoint, SymmPerChannelQuantParams&& channelQuant) : OperandType(type, dimensions, scale, zeroPoint, ExtraParams(std::move(channelQuant))) {} OperandType(Type type, std::vector dimensions, ExtraParams&& extraParams) : OperandType(type, dimensions, 0.0f, 0, std::move(extraParams)) {} }; class Model : public wrapper::Model { public: using wrapper::Model::Model; // Inherit constructors. int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) { int32_t result; if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension, &result) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } return result; } ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName, uint16_t typeWithinExtension) { ANeuralNetworksOperationType result; if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName, typeWithinExtension, &result) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } return result; } uint32_t addOperand(const OperandType* type) { if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } if (std::holds_alternative(type->extraParams)) { const auto& channelQuant = std::get(type->extraParams); if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams( mModel, mNextOperandId, &channelQuant.params) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } } else if (std::holds_alternative(type->extraParams)) { const auto& extension = std::get(type->extraParams); if (ANeuralNetworksModel_setOperandExtensionData( mModel, mNextOperandId, extension.data.data(), extension.data.size()) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } } return mNextOperandId++; } }; } // namespace extension_wrapper namespace wrapper { using ExtensionModel = extension_wrapper::Model; using ExtensionOperandType = extension_wrapper::OperandType; using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams; } // namespace wrapper } // namespace nn } // namespace android #endif // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H