1 /*
2 * Copyright (C) 2022 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ModelUtils"
18
19 #include "ModelUtils.h"
20
21 #include <android-base/logging.h>
22
23 #include <algorithm>
24 #include <numeric>
25 #include <unordered_set>
26 #include <utility>
27 #include <vector>
28
29 #include "nnapi/TypeUtils.h"
30 #include "nnapi/Types.h"
31 #include "nnapi/Validation.h"
32
33 namespace android::nn {
34 namespace {
35
// Assign a dense, zero-based index to every `true` entry of `includes`, in order of appearance.
// An entry that is `false` receives the number of `true` entries that precede it; callers must
// treat that value as meaningless ("X" below). E.g.:
//   includes = {false, true, true, false, true}
//   returned = {    X,    0,    1,     X,    2}
// Postcondition: returned.size() == includes.size()
std::vector<uint32_t> getMapping(const std::vector<bool>& includes) {
    std::vector<uint32_t> result;
    result.reserve(includes.size());
    uint32_t nextIndex = 0;
    for (const bool isIncluded : includes) {
        // Exclusive running count: record the count *before* considering this entry.
        result.push_back(nextIndex);
        if (isIncluded) {
            ++nextIndex;
        }
    }
    return result;
}
46
47 // Remap indexes in `indexes` by the mapping `mapping`.
48 // Precondition: indexes != nullptr
remapIndexes(std::vector<uint32_t> * indexes,const std::vector<uint32_t> & mapping)49 void remapIndexes(std::vector<uint32_t>* indexes, const std::vector<uint32_t>& mapping) {
50 CHECK(indexes != nullptr);
51 for (uint32_t& index : (*indexes)) {
52 index = mapping.at(index);
53 }
54 }
55
// Compact `*elements` in place so that only the entries whose flag in `elementsToKeep` is `true`
// remain, preserving their relative order. Kept entries are moved (never copied) toward the front,
// then the vector is shrunk to the number of kept entries.
// Precondition: elements != nullptr
// Precondition: elements->size() == elementsToKeep.size()
template <typename Type>
void keepSelectedElements(std::vector<Type>* elements, const std::vector<bool>& elementsToKeep) {
    CHECK(elements != nullptr);
    CHECK_EQ(elements->size(), elementsToKeep.size());

    std::vector<Type>& items = *elements;
    size_t writePos = 0;
    for (size_t readPos = 0; readPos < elementsToKeep.size(); ++readPos) {
        if (!elementsToKeep[readPos]) {
            continue;
        }
        // Skip the self-move when nothing has been removed yet.
        if (writePos != readPos) {
            items[writePos] = std::move(items[readPos]);
        }
        ++writePos;
    }
    items.resize(writePos);
}
75
76 // Find which operands in model.main.operands are read or written by model.main.operations and
77 // model.main.inputIndexes.
78 // Postcondition: returned.size() == model.main.operands.size()
identifyUsedOperands(const Model & model)79 std::vector<bool> identifyUsedOperands(const Model& model) {
80 std::vector<bool> used(model.main.operands.size(), false);
81 auto markUsed = [&used](const std::vector<uint32_t>& indexes) {
82 std::for_each(indexes.begin(), indexes.end(),
83 [&used](uint32_t index) { used.at(index) = true; });
84 };
85 for (const auto& operation : model.main.operations) {
86 markUsed(operation.inputs);
87 markUsed(operation.outputs);
88 }
89 markUsed(model.main.inputIndexes);
90 CHECK_EQ(used.size(), model.main.operands.size());
91 return used;
92 }
93
94 // Forward declaration.
95 void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
96 std::vector<bool>* used);
97
98 // Helper function to find which subgraphs are reachable by `operands`.
99 // Precondition: used != nullptr
100 // Precondition: subgraphs.size() == used->size()
identifyUsedSubgraphs(const std::vector<Operand> & operands,const std::vector<Model::Subgraph> & subgraphs,std::vector<bool> * used)101 void identifyUsedSubgraphs(const std::vector<Operand>& operands,
102 const std::vector<Model::Subgraph>& subgraphs, std::vector<bool>* used) {
103 for (const auto& operand : operands) {
104 if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
105 identifyUsedSubgraphs(operand.location.offset, subgraphs, used);
106 }
107 }
108 }
109
110 // Helper function to find which subgraphs are reachable by the subgraph at the `current` index, and
111 // store when a subgraph is used in `used`. `used` also acts as a cache, ensuring each subgraph is
112 // processed at most once.
113 // Precondition: used != nullptr
114 // Precondition: subgraphs.size() == used->size()
115 // Precondition: current < subgraphs.size()
identifyUsedSubgraphs(uint32_t current,const std::vector<Model::Subgraph> & subgraphs,std::vector<bool> * used)116 void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
117 std::vector<bool>* used) {
118 CHECK(used != nullptr);
119 CHECK_EQ(subgraphs.size(), used->size());
120 CHECK_LT(current, subgraphs.size());
121
122 // If a subgraph was already marked as used, quickly return to avoid redundant processing.
123 if ((*used)[current]) {
124 return;
125 }
126
127 // Mark the current subgraph as used, then process any subgraph it references recursively.
128 (*used)[current] = true;
129 identifyUsedSubgraphs(subgraphs[current].operands, subgraphs, used);
130 }
131
132 // Find which subgraphs are reachable by the main operands of `model`.
133 // Postcondition: returned.size() == model.referenced.size()
identifyUsedSubgraphs(const Model & model)134 std::vector<bool> identifyUsedSubgraphs(const Model& model) {
135 std::vector<bool> used(model.referenced.size(), false);
136 identifyUsedSubgraphs(model.main.operands, model.referenced, &used);
137 CHECK_EQ(used.size(), model.referenced.size());
138 return used;
139 }
140
141 // Helper function to find which pools are used by `subgraph`, and store when a pool is used in
142 // `used`.
143 // Precondition: used != nullptr
identifyUsedPools(const Model::Subgraph & subgraph,std::vector<bool> * used)144 void identifyUsedPools(const Model::Subgraph& subgraph, std::vector<bool>* used) {
145 CHECK(used != nullptr);
146 for (const auto& operand : subgraph.operands) {
147 if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE) {
148 used->at(operand.location.poolIndex) = true;
149 }
150 }
151 }
152
153 // Find which pools are used by `model`.
154 // Postcondition: returned.size() == model.pools.size()
identifyUsedPools(const Model & model)155 std::vector<bool> identifyUsedPools(const Model& model) {
156 std::vector<bool> used(model.pools.size(), false);
157 identifyUsedPools(model.main, &used);
158 for (const auto& subgraph : model.referenced) {
159 identifyUsedPools(subgraph, &used);
160 }
161 CHECK_EQ(used.size(), model.pools.size());
162 return used;
163 }
164
165 // Fix the DataLocation in `operand` by either remapping an index or by copying constant data.
166 // Precondition: operand != nullptr
167 // Precondition: newOperandValues != nullptr
fixOperandDataLocation(Operand * operand,Model::OperandValues * newOperandValues,const Model::OperandValues & oldOperandValues,const std::vector<uint32_t> & remappedPoolIndex,const std::vector<uint32_t> & remappedSubgraphIndex)168 void fixOperandDataLocation(Operand* operand, Model::OperandValues* newOperandValues,
169 const Model::OperandValues& oldOperandValues,
170 const std::vector<uint32_t>& remappedPoolIndex,
171 const std::vector<uint32_t>& remappedSubgraphIndex) {
172 CHECK(operand != nullptr);
173 CHECK(newOperandValues != nullptr);
174
175 switch (operand->lifetime) {
176 case Operand::LifeTime::CONSTANT_COPY: {
177 const uint8_t* data = oldOperandValues.data() + operand->location.offset;
178 const uint32_t length = operand->location.length;
179 operand->location = newOperandValues->append(data, length);
180 break;
181 }
182 case Operand::LifeTime::CONSTANT_REFERENCE:
183 operand->location.poolIndex = remappedPoolIndex.at(operand->location.poolIndex);
184 break;
185 case Operand::LifeTime::SUBGRAPH: {
186 uint32_t& subgraphIndex = operand->location.offset;
187 subgraphIndex = remappedSubgraphIndex.at(subgraphIndex);
188 break;
189 }
190 case Operand::LifeTime::TEMPORARY_VARIABLE:
191 case Operand::LifeTime::SUBGRAPH_INPUT:
192 case Operand::LifeTime::SUBGRAPH_OUTPUT:
193 case Operand::LifeTime::NO_VALUE:
194 case Operand::LifeTime::POINTER:
195 break;
196 }
197 }
198
199 // Fix all DataLocations in `operands` by either remapping an index or by copying constant data.
200 // Precondition: operands != nullptr
201 // Precondition: newOperandValues != nullptr
fixOperandDataLocations(std::vector<Operand> * operands,Model::OperandValues * newOperandValues,const Model::OperandValues & oldOperandValues,const std::vector<uint32_t> & remappedPoolIndex,const std::vector<uint32_t> & remappedSubgraphIndex)202 void fixOperandDataLocations(std::vector<Operand>* operands, Model::OperandValues* newOperandValues,
203 const Model::OperandValues& oldOperandValues,
204 const std::vector<uint32_t>& remappedPoolIndex,
205 const std::vector<uint32_t>& remappedSubgraphIndex) {
206 for (Operand& operand : (*operands)) {
207 fixOperandDataLocation(&operand, newOperandValues, oldOperandValues, remappedPoolIndex,
208 remappedSubgraphIndex);
209 }
210 }
211
212 // Fix all operands' DataLocations in `model` by either remapping an index or by copying constant
213 // data.
214 // Precondition: model != nullptr
fixOperandDataLocations(Model * model,const std::vector<uint32_t> & remappedPoolIndex,const std::vector<uint32_t> & remappedSubgraphIndex)215 void fixOperandDataLocations(Model* model, const std::vector<uint32_t>& remappedPoolIndex,
216 const std::vector<uint32_t>& remappedSubgraphIndex) {
217 const auto operandValues = std::exchange(model->operandValues, Model::OperandValues{});
218 fixOperandDataLocations(&model->main.operands, &model->operandValues, operandValues,
219 remappedPoolIndex, remappedSubgraphIndex);
220 for (auto& subgraph : model->referenced) {
221 fixOperandDataLocations(&subgraph.operands, &model->operandValues, operandValues,
222 remappedPoolIndex, remappedSubgraphIndex);
223 }
224 }
225
226 // Find which extensions are used in `model`.
227 // Postcondition: returned.size() == model.extensionNameToPrefix.size()
identifyUsedExtensions(const Model & model)228 std::vector<bool> identifyUsedExtensions(const Model& model) {
229 std::unordered_set<uint16_t> prefixes;
230 const auto collectPrefix = [&prefixes](const auto& operandOrOperation) {
231 const auto prefix = getExtensionPrefix(static_cast<uint32_t>(operandOrOperation.type));
232 constexpr uint16_t kStandardPrefix = 0u;
233 if (prefix != kStandardPrefix) {
234 prefixes.insert(prefix);
235 }
236 };
237 const auto collectPrefixes = [collectPrefix](const Model::Subgraph& subgraph) {
238 std::for_each(subgraph.operands.begin(), subgraph.operands.end(), collectPrefix);
239 std::for_each(subgraph.operations.begin(), subgraph.operations.end(), collectPrefix);
240 };
241
242 collectPrefixes(model.main);
243 for (const auto& subgraph : model.referenced) {
244 collectPrefixes(subgraph);
245 }
246
247 std::vector<bool> used;
248 used.reserve(model.extensionNameToPrefix.size());
249 for (const auto& extension : model.extensionNameToPrefix) {
250 used.push_back(prefixes.count(extension.prefix) > 0);
251 }
252 CHECK_EQ(used.size(), model.extensionNameToPrefix.size());
253 return used;
254 }
255
256 } // anonymous namespace
257
removeDeadOperands(Model * model)258 void removeDeadOperands(Model* model) {
259 CHECK(model != nullptr);
260
261 // Keep only the operands which are used.
262 const auto operandsUsed = identifyUsedOperands(*model);
263 keepSelectedElements(&model->main.operands, operandsUsed);
264
265 // Fix operand indexes.
266 const auto mappedOperandIndices = getMapping(operandsUsed);
267 for (auto& operation : model->main.operations) {
268 remapIndexes(&operation.inputs, mappedOperandIndices);
269 remapIndexes(&operation.outputs, mappedOperandIndices);
270 }
271 remapIndexes(&model->main.inputIndexes, mappedOperandIndices);
272 remapIndexes(&model->main.outputIndexes, mappedOperandIndices);
273
274 // Keep only the subgraphs which are used.
275 const auto subgraphsUsed = identifyUsedSubgraphs(*model);
276 keepSelectedElements(&model->referenced, subgraphsUsed);
277
278 // Keep only the pools which are used.
279 const auto poolsUsed = identifyUsedPools(*model);
280 keepSelectedElements(&model->pools, poolsUsed);
281
282 // Fix operand locations.
283 const auto mappedPoolIndices = getMapping(poolsUsed);
284 const auto mappedSubgraphIndices = getMapping(subgraphsUsed);
285 fixOperandDataLocations(model, mappedPoolIndices, mappedSubgraphIndices);
286
287 // Keep only the extensionNameToPrefixes which are used.
288 const auto extensionsUsed = identifyUsedExtensions(*model);
289 keepSelectedElements(&model->extensionNameToPrefix, extensionsUsed);
290 }
291
292 } // namespace android::nn
293