1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ModelArgumentInfo"
18
19 #include "ModelArgumentInfo.h"
20
21 #include <LegacyUtils.h>
22
23 #include <algorithm>
24 #include <utility>
25 #include <vector>
26
27 #include "NeuralNetworks.h"
28 #include "TypeManager.h"
29
30 namespace android {
31 namespace nn {
32
33 static const std::pair<int, ModelArgumentInfo> kBadDataModelArgumentInfo{ANEURALNETWORKS_BAD_DATA,
34 {}};
35
createFromPointer(const Operand & operand,const ANeuralNetworksOperandType * type,void * data,uint32_t length,bool paddingEnabled)36 std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromPointer(
37 const Operand& operand, const ANeuralNetworksOperandType* type, void* data, uint32_t length,
38 bool paddingEnabled) {
39 if ((data == nullptr) != (length == 0)) {
40 const char* dataPtrMsg = data ? "NOT_NULLPTR" : "NULLPTR";
41 LOG(ERROR) << "Data pointer must be nullptr if and only if length is zero (data = "
42 << dataPtrMsg << ", length = " << length << ")";
43 return kBadDataModelArgumentInfo;
44 }
45
46 ModelArgumentInfo ret;
47 uint32_t neededLength = 0;
48 if (data == nullptr) {
49 ret.mState = ModelArgumentInfo::HAS_NO_VALUE;
50 } else {
51 if (int n = ret.updateDimensionInfo(operand, type)) {
52 return {n, ModelArgumentInfo()};
53 }
54 if (operand.type != OperandType::OEM) {
55 neededLength = TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
56 if (neededLength > length) {
57 LOG(ERROR) << "Setting argument with invalid length: " << length
58 << ", minimum length expected: " << neededLength;
59 return kBadDataModelArgumentInfo;
60 }
61 }
62 ret.mState = ModelArgumentInfo::POINTER;
63 }
64 const uint32_t rawLength = neededLength == 0 ? length : neededLength;
65 const uint32_t padding = length - rawLength;
66
67 if (!paddingEnabled && padding > 0) {
68 LOG(ERROR) << "Setting argument with padded length without enabling input and output "
69 "padding -- length: "
70 << length << ", expected length: " << neededLength;
71 return kBadDataModelArgumentInfo;
72 }
73
74 ret.mBuffer = data;
75 ret.mLocationAndLength = {.poolIndex = 0, .offset = 0, .length = rawLength, .padding = padding};
76 return {ANEURALNETWORKS_NO_ERROR, ret};
77 }
78
createFromMemory(const Operand & operand,const ANeuralNetworksOperandType * type,uint32_t poolIndex,uint32_t offset,uint32_t length,bool paddingEnabled)79 std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromMemory(
80 const Operand& operand, const ANeuralNetworksOperandType* type, uint32_t poolIndex,
81 uint32_t offset, uint32_t length, bool paddingEnabled) {
82 ModelArgumentInfo ret;
83 if (int n = ret.updateDimensionInfo(operand, type)) {
84 return {n, ModelArgumentInfo()};
85 }
86 const bool isMemorySizeKnown = offset != 0 || length != 0;
87 uint32_t neededLength = 0;
88 if (isMemorySizeKnown && operand.type != OperandType::OEM) {
89 neededLength = TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
90 if (neededLength > length) {
91 LOG(ERROR) << "Setting argument with invalid length: " << length
92 << " (offset: " << offset << "), minimum length expected: " << neededLength;
93 return kBadDataModelArgumentInfo;
94 }
95 }
96 const uint32_t rawLength = neededLength == 0 ? length : neededLength;
97 const uint32_t padding = length - rawLength;
98
99 if (!paddingEnabled && padding > 0) {
100 LOG(ERROR) << "Setting argument with padded length without enabling input and output "
101 "padding -- length: "
102 << length << ", offset: " << offset << ", expected length: " << neededLength;
103 return kBadDataModelArgumentInfo;
104 }
105
106 ret.mState = ModelArgumentInfo::MEMORY;
107 ret.mLocationAndLength = {
108 .poolIndex = poolIndex, .offset = offset, .length = rawLength, .padding = padding};
109 ret.mBuffer = nullptr;
110 return {ANEURALNETWORKS_NO_ERROR, ret};
111 }
112
updateDimensionInfo(const Operand & operand,const ANeuralNetworksOperandType * newType)113 int ModelArgumentInfo::updateDimensionInfo(const Operand& operand,
114 const ANeuralNetworksOperandType* newType) {
115 if (newType == nullptr) {
116 mInitialDimensions = operand.dimensions;
117 } else {
118 const uint32_t count = newType->dimensionCount;
119 mInitialDimensions = std::vector<uint32_t>(count);
120 std::copy(&newType->dimensions[0], &newType->dimensions[count], mInitialDimensions.begin());
121 }
122 mDimensions = mInitialDimensions;
123 return ANEURALNETWORKS_NO_ERROR;
124 }
125
createRequestArgument() const126 Request::Argument ModelArgumentInfo::createRequestArgument() const {
127 switch (mState) {
128 case ModelArgumentInfo::POINTER: {
129 Request::Argument arg = {.lifetime = Request::Argument::LifeTime::POINTER,
130 .location = mLocationAndLength,
131 .dimensions = mDimensions};
132 arg.location.pointer = mBuffer;
133 return arg;
134 }
135 case ModelArgumentInfo::MEMORY:
136 return {.lifetime = Request::Argument::LifeTime::POOL,
137 .location = mLocationAndLength,
138 .dimensions = mDimensions};
139 case ModelArgumentInfo::HAS_NO_VALUE:
140 return {.lifetime = Request::Argument::LifeTime::NO_VALUE};
141 case ModelArgumentInfo::UNSPECIFIED:
142 LOG(FATAL) << "Invalid state: UNSPECIFIED";
143 return {};
144 };
145 LOG(FATAL) << "Invalid state: " << mState;
146 return {};
147 }
148
createRequestArguments(const std::vector<ModelArgumentInfo> & argumentInfos,const std::vector<DataLocation> & ptrArgsLocations)149 std::vector<Request::Argument> createRequestArguments(
150 const std::vector<ModelArgumentInfo>& argumentInfos,
151 const std::vector<DataLocation>& ptrArgsLocations) {
152 const size_t count = argumentInfos.size();
153 std::vector<Request::Argument> ioInfos(count);
154 uint32_t ptrArgsIndex = 0;
155 for (size_t i = 0; i < count; i++) {
156 const auto& info = argumentInfos[i];
157 switch (info.state()) {
158 case ModelArgumentInfo::POINTER:
159 ioInfos[i] = {.lifetime = Request::Argument::LifeTime::POOL,
160 .location = ptrArgsLocations[ptrArgsIndex++],
161 .dimensions = info.dimensions()};
162 break;
163 case ModelArgumentInfo::MEMORY:
164 case ModelArgumentInfo::HAS_NO_VALUE:
165 ioInfos[i] = info.createRequestArgument();
166 break;
167 default:
168 CHECK(false);
169 };
170 }
171 return ioInfos;
172 }
173
174 } // namespace nn
175 } // namespace android
176