• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <android-base/logging.h>
18 #include <nnapi/OperandTypes.h>
19 #include <nnapi/OperationTypes.h>
20 #include <nnapi/Result.h>
21 #include <nnapi/SharedMemory.h>
22 #include <nnapi/TypeUtils.h>
23 #include <nnapi/Types.h>
24 
25 #include <algorithm>
26 #include <iterator>
27 #include <limits>
28 #include <memory>
29 #include <utility>
30 #include <vector>
31 
32 #include "TestHarness.h"
33 
34 namespace android::nn::test {
35 namespace {
36 
37 using ::test_helper::TestModel;
38 using ::test_helper::TestOperand;
39 using ::test_helper::TestOperandLifeTime;
40 using ::test_helper::TestOperandType;
41 using ::test_helper::TestOperation;
42 using ::test_helper::TestSubgraph;
43 
createOperand(const TestOperand & operand,Model::OperandValues * operandValues,ConstantMemoryBuilder * memoryBuilder)44 Result<Operand> createOperand(const TestOperand& operand, Model::OperandValues* operandValues,
45                               ConstantMemoryBuilder* memoryBuilder) {
46     CHECK(operandValues != nullptr);
47     CHECK(memoryBuilder != nullptr);
48 
49     const OperandType type = static_cast<OperandType>(operand.type);
50     Operand::LifeTime lifetime = static_cast<Operand::LifeTime>(operand.lifetime);
51 
52     DataLocation location;
53     switch (operand.lifetime) {
54         case TestOperandLifeTime::TEMPORARY_VARIABLE:
55         case TestOperandLifeTime::SUBGRAPH_INPUT:
56         case TestOperandLifeTime::SUBGRAPH_OUTPUT:
57         case TestOperandLifeTime::NO_VALUE:
58             break;
59         case TestOperandLifeTime::CONSTANT_COPY:
60         case TestOperandLifeTime::CONSTANT_REFERENCE: {
61             const auto size = operand.data.size();
62             if (size == 0) {
63                 lifetime = Operand::LifeTime::NO_VALUE;
64             } else {
65                 location = (operand.lifetime == TestOperandLifeTime::CONSTANT_COPY)
66                                    ? operandValues->append(operand.data.get<uint8_t>(), size)
67                                    : memoryBuilder->append(operand.data.get<void>(), size);
68             }
69             break;
70         }
71         case TestOperandLifeTime::SUBGRAPH:
72             NN_RET_CHECK(operand.data.get<uint32_t>() != nullptr);
73             NN_RET_CHECK_GE(operand.data.size(), sizeof(uint32_t));
74             location = {.offset = *operand.data.get<uint32_t>()};
75             break;
76     }
77 
78     Operand::ExtraParams extraParams;
79     if (operand.type == TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
80         extraParams =
81                 Operand::SymmPerChannelQuantParams{.scales = operand.channelQuant.scales,
82                                                    .channelDim = operand.channelQuant.channelDim};
83     }
84 
85     return Operand{
86             .type = type,
87             .dimensions = operand.dimensions,
88             .scale = operand.scale,
89             .zeroPoint = operand.zeroPoint,
90             .lifetime = lifetime,
91             .location = location,
92             .extraParams = std::move(extraParams),
93     };
94 }
95 
createSubgraph(const TestSubgraph & testSubgraph,Model::OperandValues * operandValues,ConstantMemoryBuilder * memoryBuilder)96 Result<Model::Subgraph> createSubgraph(const TestSubgraph& testSubgraph,
97                                        Model::OperandValues* operandValues,
98                                        ConstantMemoryBuilder* memoryBuilder) {
99     // Operands.
100     std::vector<Operand> operands;
101     operands.reserve(testSubgraph.operands.size());
102     for (const auto& operand : testSubgraph.operands) {
103         operands.push_back(NN_TRY(createOperand(operand, operandValues, memoryBuilder)));
104     }
105 
106     // Operations.
107     std::vector<Operation> operations;
108     operations.reserve(testSubgraph.operations.size());
109     std::transform(testSubgraph.operations.begin(), testSubgraph.operations.end(),
110                    std::back_inserter(operations), [](const TestOperation& op) -> Operation {
111                        return {.type = static_cast<OperationType>(op.type),
112                                .inputs = op.inputs,
113                                .outputs = op.outputs};
114                    });
115 
116     return Model::Subgraph{.operands = std::move(operands),
117                            .operations = std::move(operations),
118                            .inputIndexes = testSubgraph.inputIndexes,
119                            .outputIndexes = testSubgraph.outputIndexes};
120 }
121 
122 }  // namespace
123 
createModel(const TestModel & testModel)124 GeneralResult<Model> createModel(const TestModel& testModel) {
125     Model::OperandValues operandValues;
126     ConstantMemoryBuilder memoryBuilder(0);
127 
128     Model::Subgraph mainSubgraph =
129             NN_TRY(createSubgraph(testModel.main, &operandValues, &memoryBuilder));
130     std::vector<Model::Subgraph> refSubgraphs;
131     refSubgraphs.reserve(testModel.referenced.size());
132     for (const auto& testSubgraph : testModel.referenced) {
133         refSubgraphs.push_back(
134                 NN_TRY(createSubgraph(testSubgraph, &operandValues, &memoryBuilder)));
135     }
136 
137     // Shared memory.
138     std::vector<SharedMemory> pools;
139     if (!memoryBuilder.empty()) {
140         pools.push_back(NN_TRY(memoryBuilder.finish()));
141     }
142 
143     return Model{.main = std::move(mainSubgraph),
144                  .referenced = std::move(refSubgraphs),
145                  .operandValues = std::move(operandValues),
146                  .pools = std::move(pools),
147                  .relaxComputationFloat32toFloat16 = testModel.isRelaxed};
148 }
149 
createRequest(const TestModel & testModel)150 GeneralResult<Request> createRequest(const TestModel& testModel) {
151     // Model inputs.
152     std::vector<Request::Argument> inputs;
153     inputs.reserve(testModel.main.inputIndexes.size());
154     for (uint32_t operandIndex : testModel.main.inputIndexes) {
155         NN_RET_CHECK_LT(operandIndex, testModel.main.operands.size())
156                 << "createRequest failed because inputIndex of operand " << operandIndex
157                 << " exceeds number of operands " << testModel.main.operands.size();
158 
159         const auto& op = testModel.main.operands[operandIndex];
160         Request::Argument requestArgument;
161         if (op.data.size() == 0) {
162             // Omitted input.
163             requestArgument = {.lifetime = Request::Argument::LifeTime::NO_VALUE};
164         } else {
165             const auto location = DataLocation{.pointer = op.data.get<void>(),
166                                                .length = static_cast<uint32_t>(op.data.size())};
167             requestArgument = {.lifetime = Request::Argument::LifeTime::POINTER,
168                                .location = location,
169                                .dimensions = op.dimensions};
170         }
171         inputs.push_back(std::move(requestArgument));
172     }
173 
174     // Model outputs.
175     std::vector<Request::Argument> outputs;
176     outputs.reserve(testModel.main.outputIndexes.size());
177     MutableMemoryBuilder outputBuilder(0);
178     for (uint32_t operandIndex : testModel.main.outputIndexes) {
179         NN_RET_CHECK_LT(operandIndex, testModel.main.operands.size())
180                 << "createRequest failed because outputIndex of operand " << operandIndex
181                 << " exceeds number of operands " << testModel.main.operands.size();
182 
183         const auto& op = testModel.main.operands[operandIndex];
184 
185         // In the case of zero-sized output, we should at least provide a one-byte buffer.
186         // This is because zero-sized tensors are only supported internally to the driver, or
187         // reported in output shapes. It is illegal for the client to pre-specify a zero-sized
188         // tensor as model output. Otherwise, we will have two semantic conflicts:
189         // - "Zero dimension" conflicts with "unspecified dimension".
190         // - "Omitted operand buffer" conflicts with "zero-sized operand buffer".
191         size_t bufferSize = std::max<size_t>(op.data.size(), 1);
192 
193         const DataLocation location = outputBuilder.append(bufferSize);
194         outputs.push_back({.lifetime = Request::Argument::LifeTime::POOL,
195                            .location = location,
196                            .dimensions = op.dimensions});
197     }
198 
199     // Model pools.
200     std::vector<Request::MemoryPool> pools;
201     if (!outputBuilder.empty()) {
202         pools.push_back(NN_TRY(outputBuilder.finish()));
203     }
204 
205     return Request{
206             .inputs = std::move(inputs), .outputs = std::move(outputs), .pools = std::move(pools)};
207 }
208 
209 }  // namespace android::nn::test
210