• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
18 #define ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
19 
#include <cstdint>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

#include "NeuralNetworksExtensions.h"
#include "NeuralNetworksWrapper.h"
26 
27 namespace android {
28 namespace nn {
29 namespace extension_wrapper {
30 
31 using wrapper::SymmPerChannelQuantParams;
32 using wrapper::Type;
33 
// Opaque extension-defined operand payload. Model::addOperand in this file
// forwards `data` to ANeuralNetworksModel_setOperandExtensionData().
struct ExtensionOperandParams {
    // Raw payload bytes.
    std::vector<uint8_t> data;

    ExtensionOperandParams(std::vector<uint8_t> data) : data(std::move(data)) {}

    // Builds the payload from the object representation (sizeof(T) bytes) of `value`.
    // The trivially-copyable requirement is expressed as a SFINAE constraint rather than
    // a static_assert in the body. That keeps the constructor out of overload resolution
    // for unsuitable types, so std::is_constructible and std::variant's converting
    // constructor report the truth instead of selecting a constructor that cannot compile.
    template <typename T, typename = std::enable_if_t<std::is_trivially_copyable_v<T>>>
    ExtensionOperandParams(const T& value)
        : ExtensionOperandParams(std::vector<uint8_t>(
                  reinterpret_cast<const uint8_t*>(&value),
                  reinterpret_cast<const uint8_t*>(&value) + sizeof(value))) {}
};
47 
48 struct OperandType {
49     using ExtraParams =
50             std::variant<std::monostate, SymmPerChannelQuantParams, ExtensionOperandParams>;
51 
52     ANeuralNetworksOperandType operandType;
53     std::vector<uint32_t> dimensions;
54     ExtraParams extraParams;
55 
OperandTypeOperandType56     OperandType(const OperandType& other)
57         : operandType(other.operandType),
58           dimensions(other.dimensions),
59           extraParams(other.extraParams) {
60         operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
61     }
62 
63     OperandType& operator=(const OperandType& other) {
64         if (this != &other) {
65             operandType = other.operandType;
66             dimensions = other.dimensions;
67             extraParams = other.extraParams;
68             operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr;
69         }
70         return *this;
71     }
72 
73     OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0,
74                 ExtraParams&& extraParams = std::monostate())
dimensionsOperandType75         : dimensions(std::move(d)), extraParams(std::move(extraParams)) {
76         operandType = {
77                 .type = static_cast<int32_t>(type),
78                 .dimensionCount = static_cast<uint32_t>(dimensions.size()),
79                 .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr,
80                 .scale = scale,
81                 .zeroPoint = zeroPoint,
82         };
83     }
84 
OperandTypeOperandType85     OperandType(Type type, std::vector<uint32_t> dimensions, float scale, int32_t zeroPoint,
86                 SymmPerChannelQuantParams&& channelQuant)
87         : OperandType(type, dimensions, scale, zeroPoint, ExtraParams(std::move(channelQuant))) {}
88 
OperandTypeOperandType89     OperandType(Type type, std::vector<uint32_t> dimensions, ExtraParams&& extraParams)
90         : OperandType(type, dimensions, 0.0f, 0, std::move(extraParams)) {}
91 };
92 
93 class Model : public wrapper::Model {
94    public:
95     using wrapper::Model::Model;  // Inherit constructors.
96 
getExtensionOperandType(const char * extensionName,uint16_t typeWithinExtension)97     int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) {
98         int32_t result;
99         if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension,
100                                                          &result) != ANEURALNETWORKS_NO_ERROR) {
101             mValid = false;
102         }
103         return result;
104     }
105 
getExtensionOperationType(const char * extensionName,uint16_t typeWithinExtension)106     ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName,
107                                                            uint16_t typeWithinExtension) {
108         ANeuralNetworksOperationType result;
109         if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName,
110                                                            typeWithinExtension,
111                                                            &result) != ANEURALNETWORKS_NO_ERROR) {
112             mValid = false;
113         }
114         return result;
115     }
116 
addOperand(const OperandType * type)117     uint32_t addOperand(const OperandType* type) {
118         if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
119             ANEURALNETWORKS_NO_ERROR) {
120             mValid = false;
121         }
122         if (std::holds_alternative<SymmPerChannelQuantParams>(type->extraParams)) {
123             const auto& channelQuant = std::get<SymmPerChannelQuantParams>(type->extraParams);
124             if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
125                         mModel, mNextOperandId, &channelQuant.params) != ANEURALNETWORKS_NO_ERROR) {
126                 mValid = false;
127             }
128         } else if (std::holds_alternative<ExtensionOperandParams>(type->extraParams)) {
129             const auto& extension = std::get<ExtensionOperandParams>(type->extraParams);
130             if (ANeuralNetworksModel_setOperandExtensionData(
131                         mModel, mNextOperandId, extension.data.data(), extension.data.size()) !=
132                 ANEURALNETWORKS_NO_ERROR) {
133                 mValid = false;
134             }
135         }
136         return mNextOperandId++;
137     }
138 };
139 
140 }  // namespace extension_wrapper
141 
142 namespace wrapper {
143 
144 using ExtensionModel = extension_wrapper::Model;
145 using ExtensionOperandType = extension_wrapper::OperandType;
146 using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams;
147 
148 }  // namespace wrapper
149 }  // namespace nn
150 }  // namespace android
151 
152 #endif  //  ANDROID_FRAMEWORKS_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H
153