• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ModelArgumentInfo"
18 
19 #include "ModelArgumentInfo.h"
20 
21 #include <LegacyUtils.h>
22 
23 #include <algorithm>
24 #include <utility>
25 #include <vector>
26 
27 #include "NeuralNetworks.h"
28 #include "TypeManager.h"
29 
30 namespace android {
31 namespace nn {
32 
33 static const std::pair<int, ModelArgumentInfo> kBadDataModelArgumentInfo{ANEURALNETWORKS_BAD_DATA,
34                                                                          {}};
35 
createFromPointer(const Operand & operand,const ANeuralNetworksOperandType * type,void * data,uint32_t length,bool paddingEnabled)36 std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromPointer(
37         const Operand& operand, const ANeuralNetworksOperandType* type, void* data, uint32_t length,
38         bool paddingEnabled) {
39     if ((data == nullptr) != (length == 0)) {
40         const char* dataPtrMsg = data ? "NOT_NULLPTR" : "NULLPTR";
41         LOG(ERROR) << "Data pointer must be nullptr if and only if length is zero (data = "
42                    << dataPtrMsg << ", length = " << length << ")";
43         return kBadDataModelArgumentInfo;
44     }
45 
46     ModelArgumentInfo ret;
47     uint32_t neededLength = 0;
48     if (data == nullptr) {
49         ret.mState = ModelArgumentInfo::HAS_NO_VALUE;
50     } else {
51         if (int n = ret.updateDimensionInfo(operand, type)) {
52             return {n, ModelArgumentInfo()};
53         }
54         if (operand.type != OperandType::OEM) {
55             neededLength = TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
56             if (neededLength > length) {
57                 LOG(ERROR) << "Setting argument with invalid length: " << length
58                            << ", minimum length expected: " << neededLength;
59                 return kBadDataModelArgumentInfo;
60             }
61         }
62         ret.mState = ModelArgumentInfo::POINTER;
63     }
64     const uint32_t rawLength = neededLength == 0 ? length : neededLength;
65     const uint32_t padding = length - rawLength;
66 
67     if (!paddingEnabled && padding > 0) {
68         LOG(ERROR) << "Setting argument with padded length without enabling input and output "
69                       "padding -- length: "
70                    << length << ", expected length: " << neededLength;
71         return kBadDataModelArgumentInfo;
72     }
73 
74     ret.mBuffer = data;
75     ret.mLocationAndLength = {.poolIndex = 0, .offset = 0, .length = rawLength, .padding = padding};
76     return {ANEURALNETWORKS_NO_ERROR, ret};
77 }
78 
createFromMemory(const Operand & operand,const ANeuralNetworksOperandType * type,uint32_t poolIndex,uint32_t offset,uint32_t length,bool paddingEnabled)79 std::pair<int, ModelArgumentInfo> ModelArgumentInfo::createFromMemory(
80         const Operand& operand, const ANeuralNetworksOperandType* type, uint32_t poolIndex,
81         uint32_t offset, uint32_t length, bool paddingEnabled) {
82     ModelArgumentInfo ret;
83     if (int n = ret.updateDimensionInfo(operand, type)) {
84         return {n, ModelArgumentInfo()};
85     }
86     const bool isMemorySizeKnown = offset != 0 || length != 0;
87     uint32_t neededLength = 0;
88     if (isMemorySizeKnown && operand.type != OperandType::OEM) {
89         neededLength = TypeManager::get()->getSizeOfData(operand.type, ret.mDimensions);
90         if (neededLength > length) {
91             LOG(ERROR) << "Setting argument with invalid length: " << length
92                        << " (offset: " << offset << "), minimum length expected: " << neededLength;
93             return kBadDataModelArgumentInfo;
94         }
95     }
96     const uint32_t rawLength = neededLength == 0 ? length : neededLength;
97     const uint32_t padding = length - rawLength;
98 
99     if (!paddingEnabled && padding > 0) {
100         LOG(ERROR) << "Setting argument with padded length without enabling input and output "
101                       "padding -- length: "
102                    << length << ", offset: " << offset << ", expected length: " << neededLength;
103         return kBadDataModelArgumentInfo;
104     }
105 
106     ret.mState = ModelArgumentInfo::MEMORY;
107     ret.mLocationAndLength = {
108             .poolIndex = poolIndex, .offset = offset, .length = rawLength, .padding = padding};
109     ret.mBuffer = nullptr;
110     return {ANEURALNETWORKS_NO_ERROR, ret};
111 }
112 
updateDimensionInfo(const Operand & operand,const ANeuralNetworksOperandType * newType)113 int ModelArgumentInfo::updateDimensionInfo(const Operand& operand,
114                                            const ANeuralNetworksOperandType* newType) {
115     if (newType == nullptr) {
116         mInitialDimensions = operand.dimensions;
117     } else {
118         const uint32_t count = newType->dimensionCount;
119         mInitialDimensions = std::vector<uint32_t>(count);
120         std::copy(&newType->dimensions[0], &newType->dimensions[count], mInitialDimensions.begin());
121     }
122     mDimensions = mInitialDimensions;
123     return ANEURALNETWORKS_NO_ERROR;
124 }
125 
createRequestArgument() const126 Request::Argument ModelArgumentInfo::createRequestArgument() const {
127     switch (mState) {
128         case ModelArgumentInfo::POINTER: {
129             Request::Argument arg = {.lifetime = Request::Argument::LifeTime::POINTER,
130                                      .location = mLocationAndLength,
131                                      .dimensions = mDimensions};
132             arg.location.pointer = mBuffer;
133             return arg;
134         }
135         case ModelArgumentInfo::MEMORY:
136             return {.lifetime = Request::Argument::LifeTime::POOL,
137                     .location = mLocationAndLength,
138                     .dimensions = mDimensions};
139         case ModelArgumentInfo::HAS_NO_VALUE:
140             return {.lifetime = Request::Argument::LifeTime::NO_VALUE};
141         case ModelArgumentInfo::UNSPECIFIED:
142             LOG(FATAL) << "Invalid state: UNSPECIFIED";
143             return {};
144     };
145     LOG(FATAL) << "Invalid state: " << mState;
146     return {};
147 }
148 
createRequestArguments(const std::vector<ModelArgumentInfo> & argumentInfos,const std::vector<DataLocation> & ptrArgsLocations)149 std::vector<Request::Argument> createRequestArguments(
150         const std::vector<ModelArgumentInfo>& argumentInfos,
151         const std::vector<DataLocation>& ptrArgsLocations) {
152     const size_t count = argumentInfos.size();
153     std::vector<Request::Argument> ioInfos(count);
154     uint32_t ptrArgsIndex = 0;
155     for (size_t i = 0; i < count; i++) {
156         const auto& info = argumentInfos[i];
157         switch (info.state()) {
158             case ModelArgumentInfo::POINTER:
159                 ioInfos[i] = {.lifetime = Request::Argument::LifeTime::POOL,
160                               .location = ptrArgsLocations[ptrArgsIndex++],
161                               .dimensions = info.dimensions()};
162                 break;
163             case ModelArgumentInfo::MEMORY:
164             case ModelArgumentInfo::HAS_NO_VALUE:
165                 ioInfos[i] = info.createRequestArgument();
166                 break;
167             default:
168                 CHECK(false);
169         };
170     }
171     return ioInfos;
172 }
173 
174 }  // namespace nn
175 }  // namespace android
176