1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "MetaModel"
18
19 #include "MetaModel.h"
20
21 #include <algorithm>
22 #include <map>
23 #include <numeric>
24 #include <set>
25 #include <sstream>
26 #include <type_traits>
27 #include <utility>
28 #include <vector>
29
30 #include "GraphDump.h"
31 #include "LegacyUtils.h"
32 #include "nnapi/TypeUtils.h"
33 #include "nnapi/Types.h"
34 #include "nnapi/Validation.h"
35
36 namespace android::nn {
37
38 namespace {
39
// Appends a copy of `val` to `*vec`.
//
// Returns {index, pointer} describing the newly appended element. The pointer
// refers into the vector's storage and is invalidated by any later
// reallocation of `*vec`.
template <class T>
std::pair<uint32_t, T*> extend(std::vector<T>* vec, const T& val) {
    // The new element's index is the vector's size before the insertion.
    const auto newIndex = static_cast<uint32_t>(vec->size());
    vec->push_back(val);
    T* const newElement = &vec->back();
    return std::make_pair(newIndex, newElement);
}
48
// Appends a value-initialized element to `*vec`.
//
// Returns {index, pointer} describing the newly appended element. The pointer
// is invalidated by any later reallocation of `*vec`.
template <class T>
std::pair<uint32_t, T*> extend(std::vector<T>* vec) {
    const auto newIndex = static_cast<uint32_t>(vec->size());
    // emplace_back() with no arguments value-initializes the new element in place.
    vec->emplace_back();
    return {newIndex, &vec->back()};
}
55
invalid(const Model & model,Version version,bool strictSlicing)56 bool invalid(const Model& model, Version version, bool strictSlicing) {
57 // A model must have at least one operation. However, it's possible that a
58 // slice has no operations (because no operations from the original model
59 // are compliant with the sliced model type). In this case, the sliced
60 // model would be invalid.
61 const bool looksEmpty = (model.main.operations.size() == 0);
62 if (strictSlicing) {
63 CHECK_EQ(looksEmpty, (model.main.operands.size() == 0));
64 }
65 if (looksEmpty) return true;
66
67 // A model must have at least one output. However, it's possible for a
68 // model to contain dead operations (i.e., outputs on which no model outputs
69 // are data dependent). A slice might contain only dead operations, and
70 // hence have no model outputs. In this case, the sliced model would be
71 // invalid.
72 if (model.main.outputIndexes.size() == 0) return true;
73
74 // We shouldn't have to check whether the model is valid. However, it could
75 // be invalid if there is an error in the slicing algorithm.
76 auto maybeVersion = validate(model);
77 if (!maybeVersion.has_value()) {
78 LOG(WARNING) << "Sliced model fails validate(): " << maybeVersion.error();
79 CHECK(!strictSlicing);
80 return true;
81 }
82 if (maybeVersion.value() > version) {
83 LOG(WARNING) << "Sliced model fails validate(): insufficient version ("
84 << maybeVersion.value() << " vs " << version << ")";
85 CHECK(!strictSlicing);
86 return true;
87 }
88
89 return false;
90 }
91
92 } // anonymous namespace
93
MetaModel(Model model,bool strictSlicing)94 MetaModel::MetaModel(Model model, bool strictSlicing)
95 : mModel(std::move(model)),
96 mModelMinimumSupportedVersion(validate(mModel).value()),
97 mStrictSlicing(strictSlicing) {}
98
getSlice(Version version) const99 MetaModel::ReturnedSlice MetaModel::getSlice(Version version) const {
100 // All slices of versions of at least mModelMinimumSupportedVersion are identical, so do not
101 // create more than one such slice.
102 version = std::min(version, mModelMinimumSupportedVersion);
103
104 auto& slice = mCachedSlices[version];
105 if (slice.mState == SliceState::UNINITIALIZED) {
106 slice = makeSlice(version);
107 }
108 if (slice.mState == SliceState::INVALID) {
109 return {};
110 }
111 return MetaModel::ReturnedSlice(std::make_pair(
112 slice.mModel, Mapper([&slice](uint32_t slicedOperationIndex) {
113 return slice.mSlicedOperationIndexToOrigIndex.at(slicedOperationIndex);
114 })));
115 }
116
117 // Utility class for makeSlice().
118 //
119 // For each output operand of a noncompliant operation that is the input
120 // operand of at least one compliant operation, we will ensure that there is
121 // a sliced model input whose "type" is that of the output operand. This is
122 // a map from operand "type" (in the original model) to model input operand
123 // index (in the sliced model). We only use the subset of the fields that are
124 // relevant (OperandType, dimensions, scale, zeroPoint, extraParams), but
125 // exclude irrelevant fields from the map key (lifetime, location).
126 //
127 // We also use this map for model input operands of the original model that
128 // become input operands of the sliced model. This means that an original
129 // model input operand might be commoned with other original model input
130 // operands and/or with original model temporary operands.
131 class MetaModel::OrigOperandToSlicedInputOperandIndex {
132 public:
133 // `slicedOperands` and `slicedInputIndexes` will be modified as part of
134 // OrigOperandToSlicedInputOperandIndex::getIndex. `slicedVersion`, `operandValuesSize`, and
135 // `poolSizes` are used as a check to ensure that the sliced operand is valid and compliant with
136 // the sliced version. `operandValuesSize` is the size of the operand values in the sliced model
137 // (which is the same as the original model). `poolSizes` is the size of the memories in the
138 // sliced model (which is the same as the original model).
OrigOperandToSlicedInputOperandIndex(std::vector<Operand> * slicedOperands,std::vector<uint32_t> * slicedInputIndexes,Version slicedVersion,size_t operandValuesSize,std::vector<size_t> poolSizes)139 OrigOperandToSlicedInputOperandIndex(std::vector<Operand>* slicedOperands,
140 std::vector<uint32_t>* slicedInputIndexes,
141 Version slicedVersion, size_t operandValuesSize,
142 std::vector<size_t> poolSizes)
143 : mSlicedOperands(*slicedOperands),
144 mSlicedInputIndexes(*slicedInputIndexes),
145 kSlicedVersion(slicedVersion),
146 kOperandValuesSize(operandValuesSize),
147 kPoolSizes(std::move(poolSizes)) {}
148
149 // Given an operand from the original model, return the index of the
150 // corresponding model input operand from the sliced model. Creates a
151 // new operand in the sliced model if necessary.
getIndex(Operand operand)152 uint32_t getIndex(Operand operand) {
153 CHECK(operand.lifetime == Operand::LifeTime::SUBGRAPH_INPUT ||
154 operand.lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT ||
155 operand.lifetime == Operand::LifeTime::TEMPORARY_VARIABLE);
156
157 // Lookup
158 auto it = mMap.find(operand);
159 if (it != mMap.end()) {
160 VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex looked for "
161 << operand << " and found " << it->second << ": " << it->first;
162 return it->second;
163 }
164
165 // Create
166 operand.lifetime = Operand::LifeTime::SUBGRAPH_INPUT;
167 operand.location = {};
168
169 // Note that the sliced model does not contain any referenced subgraphs, so both `subgraphs`
170 // and `subgraphVersionCache` are empty.
171 const std::vector<Model::Subgraph> subgraphs;
172 auto subgraphVersionCache = createSubgraphVersionCache(subgraphs.size());
173 const auto minimumSupportedOperandVersion =
174 validateOperandAndAnythingItDependsOn(operand, kOperandValuesSize, kPoolSizes,
175 subgraphs, subgraphVersionCache.get())
176 .value();
177 CHECK_LE(minimumSupportedOperandVersion, kSlicedVersion);
178
179 uint32_t slicedOperandIndex = extend(&mSlicedOperands, operand).first;
180 mMap[operand] = slicedOperandIndex;
181 extend(&mSlicedInputIndexes, slicedOperandIndex);
182 VLOG(COMPILATION) << "OrigOperandToSlicedInputOperandIndex::getIndex created "
183 << slicedOperandIndex << ": " << operand;
184 return slicedOperandIndex;
185 }
186
187 private:
188 class Compare {
189 public:
operator ()(const Operand & a,const Operand & b) const190 bool operator()(const Operand& a, const Operand& b) const {
191 if (a.type != b.type) {
192 return a.type < b.type;
193 }
194 if (a.dimensions != b.dimensions) {
195 return a.dimensions < b.dimensions;
196 }
197 if (a.scale != b.scale) {
198 return a.scale < b.scale;
199 }
200 if (a.zeroPoint != b.zeroPoint) {
201 return a.zeroPoint < b.zeroPoint;
202 }
203 return compare(a.extraParams, b.extraParams);
204 }
205
206 private:
compare(const Operand::SymmPerChannelQuantParams & a,const Operand::SymmPerChannelQuantParams & b)207 static bool compare(const Operand::SymmPerChannelQuantParams& a,
208 const Operand::SymmPerChannelQuantParams& b) {
209 if (a.scales != b.scales) {
210 return a.scales < b.scales;
211 }
212 return a.channelDim < b.channelDim;
213 }
compare(const Operand::ExtraParams & a,const Operand::ExtraParams & b)214 static bool compare(const Operand::ExtraParams& a, const Operand::ExtraParams& b) {
215 if (a.index() != b.index()) {
216 return a.index() < b.index();
217 }
218 if (std::holds_alternative<Operand::SymmPerChannelQuantParams>(a)) {
219 return compare(std::get<Operand::SymmPerChannelQuantParams>(a),
220 std::get<Operand::SymmPerChannelQuantParams>(b));
221 }
222 if (std::holds_alternative<Operand::ExtensionParams>(a)) {
223 return std::get<Operand::ExtensionParams>(a) <
224 std::get<Operand::ExtensionParams>(b);
225 }
226 if (std::holds_alternative<Operand::NoParams>(a)) {
227 return false;
228 }
229 CHECK(false) << "Unexpected";
230 return false;
231 }
232 };
233 std::map<Operand, uint32_t, Compare> mMap;
234 std::vector<Operand>& mSlicedOperands;
235 std::vector<uint32_t>& mSlicedInputIndexes;
236 const Version kSlicedVersion;
237 const size_t kOperandValuesSize;
238 const std::vector<size_t> kPoolSizes;
239 };
240
processOperations(Slice * slice,std::map<uint32_t,uint32_t> * origOperandIndexToSlicedIndex,OrigOperandToSlicedInputOperandIndex * origOperandToSlicedInputOperandIndex,const std::set<uint32_t> & noncompliantOperations,const std::set<uint32_t> & inputOperandIndexesOfCompliantOperations) const241 void MetaModel::processOperations(
242 Slice* slice, std::map<uint32_t, uint32_t>* origOperandIndexToSlicedIndex,
243 OrigOperandToSlicedInputOperandIndex* origOperandToSlicedInputOperandIndex,
244 const std::set<uint32_t>& noncompliantOperations,
245 const std::set<uint32_t>& inputOperandIndexesOfCompliantOperations) const {
246 const auto& origOperands = mModel.main.operands;
247 const auto& origOperations = mModel.main.operations;
248 auto& slicedOperands = slice->mModel.main.operands;
249 auto& slicedOperations = slice->mModel.main.operations;
250
251 std::vector<uint32_t> origOperandNumberOfConsumers =
252 countNumberOfConsumers(origOperands.size(), origOperations).value();
253
254 for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size();
255 ++origOperationIndex) {
256 const Operation& origOperation = origOperations[origOperationIndex];
257
258 if (noncompliantOperations.count(origOperationIndex)) {
259 for (uint32_t output : origOperation.outputs) {
260 if (!inputOperandIndexesOfCompliantOperations.count(output)) {
261 continue;
262 }
263 const uint32_t slicedIndex =
264 origOperandToSlicedInputOperandIndex->getIndex(origOperands[output]);
265 (*origOperandIndexToSlicedIndex)[output] = slicedIndex;
266 VLOG(COMPILATION)
267 << "origOperandIndexToSlicedIndex noncompliant output processing created "
268 << output << " -> " << slicedIndex << ": " << slicedOperands[slicedIndex];
269 }
270 } else {
271 slice->mSlicedOperationIndexToOrigIndex.push_back(origOperationIndex);
272 Operation& slicedOperation = *extend(&slicedOperations).second;
273 CHECK_EQ(slice->mSlicedOperationIndexToOrigIndex.size(), slicedOperations.size());
274
275 slicedOperation.type = origOperation.type;
276
277 // Model is topologically sorted, so all operation inputs must be
278 // present in origOperandIndexToSlicedIndex, and no operation
279 // outputs may be.
280
281 // Operation inputs
282 // - Fill in slicedOperation.inputs
283 slicedOperation.inputs.resize(origOperation.inputs.size());
284 std::transform(
285 origOperation.inputs.begin(), origOperation.inputs.end(),
286 slicedOperation.inputs.begin(),
287 [&origOperandIndexToSlicedIndex, &slicedOperands](uint32_t origOperandIndex) {
288 uint32_t slicedOperandIndex =
289 origOperandIndexToSlicedIndex->at(origOperandIndex);
290 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant input "
291 "processing created "
292 << origOperandIndex << " -> " << slicedOperandIndex
293 << ": " << slicedOperands[slicedOperandIndex];
294 return slicedOperandIndex;
295 });
296
297 // Operation outputs
298 // - Add new operands to slicedOperands
299 // - Update origOperandIndexToSlicedIndex
300 // - Fill in slicedOperation.outputs
301 // - Record as a model output, if necessary
302 const uint32_t firstOutputSlicedOperandIndex = slicedOperands.size();
303 slicedOperands.resize(firstOutputSlicedOperandIndex + origOperation.outputs.size());
304 slicedOperation.outputs.resize(origOperation.outputs.size());
305 for (uint32_t outputNum = 0; outputNum < slicedOperation.outputs.size(); ++outputNum) {
306 uint32_t origOperandIndex = origOperation.outputs[outputNum];
307 uint32_t slicedOperandIndex = firstOutputSlicedOperandIndex + outputNum;
308 auto& slicedOperand = slicedOperands[slicedOperandIndex];
309 const auto& origOperand = origOperands[origOperandIndex];
310 slicedOperand = origOperand;
311
312 CHECK_EQ(origOperandIndexToSlicedIndex->count(origOperandIndex), size_t(0));
313 (*origOperandIndexToSlicedIndex)[origOperandIndex] = slicedOperandIndex;
314 slicedOperation.outputs[outputNum] = slicedOperandIndex;
315
316 const auto subgraphOutputLifetime = Operand::LifeTime::SUBGRAPH_OUTPUT;
317 if (!inputOperandIndexesOfCompliantOperations.count(origOperandIndex) &&
318 origOperandNumberOfConsumers[origOperandIndex] != 0) {
319 // Was consumed only by noncompliant operations; convert to
320 // an output of the sliced model.
321 slicedOperand.lifetime = subgraphOutputLifetime;
322 }
323
324 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex compliant output created "
325 << origOperandIndex << " -> " << slicedOperandIndex << ": "
326 << slicedOperand;
327
328 if (slicedOperand.lifetime == subgraphOutputLifetime) {
329 extend(&slice->mModel.main.outputIndexes, slicedOperandIndex);
330 }
331 }
332 }
333 }
334 }
335
getNoncompliantOperations(Version version) const336 std::set<uint32_t> MetaModel::getNoncompliantOperations(Version version) const {
337 const auto [operandValuesSize, poolSizes] = getMemorySizes(mModel);
338
339 auto subgraphVersionCache = createSubgraphVersionCache(mModel.referenced.size());
340 std::set<uint32_t> noncompliantOperations;
341 for (uint32_t i = 0; i < mModel.main.operations.size(); ++i) {
342 const auto& operation = mModel.main.operations[i];
343 const auto minSupportedVersion =
344 validateOperationAndAnythingItDependsOn(
345 operation, mModel.main.operands, operandValuesSize, poolSizes,
346 mModel.referenced, subgraphVersionCache.get())
347 .value();
348 if (minSupportedVersion > version) {
349 noncompliantOperations.insert(i);
350 }
351 }
352 return noncompliantOperations;
353 }
354
makeSlice(Version version) const355 MetaModel::Slice MetaModel::makeSlice(Version version) const {
356 Slice slice;
357
358 // Quickly return if the model is already compliant with `version`
359 if (version >= mModelMinimumSupportedVersion) {
360 slice.mModel = mModel;
361 slice.mSlicedOperationIndexToOrigIndex =
362 std::vector<uint32_t>(mModel.main.operations.size());
363 std::iota(slice.mSlicedOperationIndexToOrigIndex.begin(),
364 slice.mSlicedOperationIndexToOrigIndex.end(), 0u);
365 slice.mState = SliceState::NORMAL;
366 return slice;
367 }
368
369 const auto& origOperands = mModel.main.operands;
370 const auto& origOperations = mModel.main.operations;
371 auto& slicedOperands = slice.mModel.main.operands;
372
373 // Indexes of elements of noncompliant origOperations
374 std::set<uint32_t> noncompliantOperations = getNoncompliantOperations(version);
375
376 // Check if any compliant operations require a subgraph.
377 bool someCompliantOperationHasASubgraphOperand = false;
378 if (!mModel.referenced.empty()) {
379 for (size_t i = 0; i < mModel.main.operations.size(); ++i) {
380 const auto& operation = mModel.main.operations[i];
381 if (noncompliantOperations.count(i) > 0) {
382 continue;
383 }
384 const auto isSubgraph = [&origOperands](uint32_t opndIdx) {
385 return origOperands[opndIdx].lifetime == Operand::LifeTime::SUBGRAPH;
386 };
387 if (std::any_of(operation.inputs.begin(), operation.inputs.end(), isSubgraph)) {
388 someCompliantOperationHasASubgraphOperand = true;
389 break;
390 }
391 }
392 }
393
394 // TODO(b/175418767): Currently, MetaModel is not equipped to slice referenced subgraphs. If the
395 // original model is not compliant with the specified version and contains referenced subgraphs
396 // needed by the slice, return an invalidated slice.
397 if (someCompliantOperationHasASubgraphOperand) {
398 slice.mState = SliceState::INVALID;
399 return slice;
400 }
401
402 // Map from an operand index in origOperands to the corresponding operand index in
403 // slicedOperands
404 std::map<uint32_t, uint32_t> origOperandIndexToSlicedIndex;
405
406 // Collect the operand indexes of every operand that is an input to a
407 // compliant operation. If the operand is a CONSTANT_*, POINTER, or a
408 // NO_VALUE, copy it to the sliced model and update
409 // origOperandIndexToSlicedIndex accordingly. Otherwise, we'll deal with
410 // the operand in the subsequent "Main loop", where we process operation
411 // outputs (intermediates and model outputs).
412 std::set<uint32_t> inputOperandIndexesOfCompliantOperations;
413 for (uint32_t origOperationIndex = 0; origOperationIndex < origOperations.size();
414 ++origOperationIndex) {
415 if (noncompliantOperations.count(origOperationIndex)) {
416 continue;
417 }
418 for (uint32_t input : origOperations[origOperationIndex].inputs) {
419 if (inputOperandIndexesOfCompliantOperations.insert(input).second) {
420 const Operand& origOperand = origOperands[input];
421 switch (origOperand.lifetime) {
422 case Operand::LifeTime::CONSTANT_COPY:
423 case Operand::LifeTime::CONSTANT_REFERENCE:
424 case Operand::LifeTime::POINTER:
425 case Operand::LifeTime::NO_VALUE: {
426 const uint32_t slicedOperandIndex =
427 extend(&slicedOperands, origOperand).first;
428 origOperandIndexToSlicedIndex[input] = slicedOperandIndex;
429 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex initialization created "
430 << input << " -> " << slicedOperandIndex << ": "
431 << slicedOperands[slicedOperandIndex];
432 break;
433 }
434 default:
435 break;
436 }
437 }
438 }
439 }
440
441 const auto [operandValuesSize, poolSizes] = getMemorySizes(mModel);
442
443 OrigOperandToSlicedInputOperandIndex origOperandToSlicedInputOperandIndex(
444 &slicedOperands, &slice.mModel.main.inputIndexes, version, operandValuesSize,
445 poolSizes);
446
447 // An input of the original model is an input of the sliced model if and
448 // only if it is consumed by at least one compliant operation. Note that in
449 // the sliced model we share all model inputs of the same "type"; and that
450 // we may later add model inputs to the sliced model.
451 for (uint32_t origInputIndex : mModel.main.inputIndexes) {
452 if (inputOperandIndexesOfCompliantOperations.count(origInputIndex)) {
453 const uint32_t slicedIndex =
454 origOperandToSlicedInputOperandIndex.getIndex(origOperands[origInputIndex]);
455 origOperandIndexToSlicedIndex[origInputIndex] = slicedIndex;
456 VLOG(COMPILATION) << "origOperandIndexToSlicedIndex inputIndexes processing created "
457 << origInputIndex << " -> " << slicedIndex << ": "
458 << slicedOperands[slicedIndex];
459 }
460 }
461
462 // Main loop: Process each operation of the original model.
463 processOperations(&slice, &origOperandIndexToSlicedIndex, &origOperandToSlicedInputOperandIndex,
464 noncompliantOperations, inputOperandIndexesOfCompliantOperations);
465
466 // To keep things simple, we copy over these fields as-is. We could instead
467 // opt to regenerate them based on the operands present in the sliced model:
468 // This would be more complex and probably take more computation time, but
469 // it would reduce the size of the sliced model, and hence the time spent
470 // copying it around and potentially passing it across process boundaries.
471 slice.mModel.operandValues = mModel.operandValues;
472 slice.mModel.pools = mModel.pools;
473
474 if (VLOG_IS_ON(COMPILATION)) {
475 {
476 std::ostringstream fromName;
477 fromName << "Slice: From canonical";
478 graphDump(fromName.str().c_str(), mModel);
479 }
480 {
481 std::ostringstream toName;
482 toName << "Slice: To " << version;
483 graphDump(toName.str().c_str(), slice.mModel);
484 }
485 }
486
487 slice.mState = invalid(slice.mModel, version, mStrictSlicing) ? SliceState::INVALID
488 : SliceState::NORMAL;
489
490 return slice;
491 }
492
493 } // namespace android::nn
494