/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "ModelBuilder"

#include "ModelBuilder.h"

#include <GraphDump.h>
#include <LegacyUtils.h>

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "CompilationBuilder.h"
#include "Manager.h"
#include "TypeManager.h"

namespace android {
namespace nn {

// The maximum number of operands and operations that a model may have.
const uint32_t MAX_NUMBER_OF_OPERANDS = 0xFFFFFFFE;
const uint32_t MAX_NUMBER_OF_OPERATIONS = 0xFFFFFFFE;

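// Checks that an NDK (data, length) argument pair is consistent: the pointer must be
// null exactly when the length is zero; otherwise logs the mismatch and returns
// ANEURALNETWORKS_BAD_DATA from the enclosing function.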
#define NN_VALIDATE_NULL_OR_SIZED(tag, data, length)                                          \
    if ((data == nullptr) != (length == 0)) {                                                 \
        LOG(ERROR) << "ANeuralNetworksModel_" << tag << " " << #data << " is "                \
                   << (data == nullptr ? "null" : "not null") << " but " << #length << " is " \
                   << length;                                                                 \
        return ANEURALNETWORKS_BAD_DATA;                                                      \
    }

template <typename Type>
static std::vector<Type> makeVector(const Type* data, uint32_t length) {
    return length > 0 ? std::vector<Type>(data, data + length) : std::vector<Type>();
}

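// Returns true (and logs which ANeuralNetworksModel_* call was rejected) if the model
// can no longer be modified, i.e. it has already been finished or marked invalid.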
bool ModelBuilder::badState(const char* name) {
    if (mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_" << name << " can't modify after model finished";
        return true;
    }
    if (mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_" << name << " can't modify an invalid model";
        return true;
    }
    return false;
}

int ModelBuilder::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
                                   int32_t* type) {
    return TypeManager::get()->getExtensionType(extensionName, typeWithinExtension, type)
                   ? ANEURALNETWORKS_NO_ERROR
                   : ANEURALNETWORKS_BAD_DATA;
}

int ModelBuilder::addOperand(const ANeuralNetworksOperandType& type) {
    if (badState("addOperand")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    OperandType operandType = static_cast<OperandType>(type.type);
    if (isExtension(operandType) && !TypeManager::get()->areExtensionsAllowed()) {
        LOG(ERROR) << "Extensions are not supported for this process.";
        return ANEURALNETWORKS_BAD_DATA;
    }
    bool isOemOperand =
            operandType == OperandType::OEM || operandType == OperandType::TENSOR_OEM_BYTE;
    if (isOemOperand && !mHasOEMOperand) {
        LOG(WARNING) << "OEM data type is deprecated. Use Extensions instead.";
    }

    const Extension::OperandTypeInformation* info = nullptr;
    if (isExtension(operandType) &&
        !TypeManager::get()->getExtensionOperandTypeInfo(operandType, &info)) {
        LOG(ERROR) << "Extension operand type " << operandType << " is not registered";
        return ANEURALNETWORKS_BAD_DATA;
    }
    NN_VALIDATE_NULL_OR_SIZED("addOperand", type.dimensions, type.dimensionCount);
    Operand operand = {
            .type = operandType,
            .dimensions = makeVector(type.dimensions, type.dimensionCount),
            .scale = type.scale,
            .zeroPoint = type.zeroPoint,
            .lifetime = Operand::LifeTime::TEMPORARY_VARIABLE,
            .location = {.poolIndex = 0, .offset = 0, .length = 0},
            .extraParams = {},
    };
    if (auto result = validateOperandType(operand, info, "ANeuralNetworksModel_addOperand", true);
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    size_t idx = mOperands.size();
    if (idx >= MAX_NUMBER_OF_OPERANDS) {
        LOG(ERROR) << "ANeuralNetworksModel_addOperand exceeds max operands";
        return ANEURALNETWORKS_BAD_DATA;
    }

    mOperands.push_back(std::move(operand));
    mHasOEMOperand |= isOemOperand;
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValue(uint32_t index, const void* buffer, size_t length) {
    VLOG(MODEL) << __func__ << " for operand " << index << " size " << length;
    if (badState("setOperandValue")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index << " of "
                   << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    NN_VALIDATE_NULL_OR_SIZED("setOperandValue", buffer, length);
    if (buffer == nullptr) {
        operand.lifetime = Operand::LifeTime::NO_VALUE;
        // The location is unused and is set to zeros.
        operand.location = {.poolIndex = 0, .offset = 0, .length = 0};
    } else {
        if (TypeManager::get()->isTensorType(operand.type) &&
            tensorHasUnspecifiedDimensions(operand)) {
            LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index
                       << " which has operand type that is not fully specified";
            return ANEURALNETWORKS_BAD_DATA;
        }
        if (length > 0xFFFFFFFF) {
            LOG(ERROR) << "ANeuralNetworksModel_setOperandValue value length of " << length
                       << " exceeds max size";
            return ANEURALNETWORKS_BAD_DATA;
        }
        uint32_t valueLength = static_cast<uint32_t>(length);
        if (operand.type != OperandType::OEM) {
            uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
            if (neededLength != valueLength) {
                LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting " << valueLength
                           << " bytes when needing " << neededLength;
                return ANEURALNETWORKS_BAD_DATA;
            }
        }
        if (valueLength <= ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES) {
            uint32_t existingSize = static_cast<uint32_t>(mSmallOperandValues.size());
            uint32_t extraBytes = alignBytesNeeded(existingSize, valueLength);
            mSmallOperandValues.resize(existingSize + extraBytes + valueLength);
            operand.lifetime = Operand::LifeTime::CONSTANT_COPY;
            operand.location = {
                    .poolIndex = 0, .offset = existingSize + extraBytes, .length = valueLength};
            memcpy(&mSmallOperandValues[operand.location.offset], buffer, valueLength);
            VLOG(MODEL) << "Copied small value to offset " << operand.location.offset;
        } else {
            VLOG(MODEL) << "Saving large value";
            operand.lifetime = Operand::LifeTime::CONSTANT_REFERENCE;
            // The values for poolIndex and offset will be set when the model is finished.
            typedef decltype(operand.location.poolIndex) PoolIndexType;
            typedef decltype(operand.location.offset) OffsetType;
            operand.location = {.poolIndex = ~PoolIndexType(0),
                                .offset = ~OffsetType(0),
                                .length = valueLength};
            // We keep track of the buffers. We'll allocate the shared memory only
            // once we know the total size, to avoid needless copies.
            mLargeOperandValues.push_back(LargeValue{.operandIndex = index, .buffer = buffer});
        }
    }
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValueFromModel(uint32_t index, const ModelBuilder* value) {
    VLOG(MODEL) << __func__ << " for operand " << index << " model " << value;
    if (badState("setOperandValueFromModel")) {
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (!value->mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel value model must be finished";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (value->mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel value model is invalid";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromModel setting operand " << index
                   << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    operand.lifetime = Operand::LifeTime::SUBGRAPH;
    operand.location = {
            .poolIndex = 0,
            .offset = static_cast<uint32_t>(mReferencedModels.size()),
            .length = 0,
    };
    mReferencedModels.push_back(value);
    mReferencedSubgraphsForValidation.push_back(value->makeModel().main);
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandSymmPerChannelQuantParams(
        uint32_t index, const ANeuralNetworksSymmPerChannelQuantParams& channelQuant) {
    if (badState("setOperandSymmPerChannelQuantParams")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
                   << "setting per-channel quantization parameters for operand " << index << " of "
                   << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];

    NN_VALIDATE_NULL_OR_SIZED("setOperandSymmPerChannelQuantParams", channelQuant.scales,
                              channelQuant.scaleCount);
    Operand::SymmPerChannelQuantParams extraParams = {
            .scales = makeVector(channelQuant.scales, channelQuant.scaleCount),
            .channelDim = channelQuant.channelDim,
    };
    if (auto result = validateOperandSymmPerChannelQuantParams(
                operand, extraParams, "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }
    switch (operand.type) {
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            operand.extraParams = std::move(extraParams);
            break;
        default:
            LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
                       << "invalid operand type " << static_cast<int32_t>(operand.type);
            return ANEURALNETWORKS_BAD_DATA;
    }
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandExtensionData(uint32_t index, const void* data, size_t length) {
    if (badState("setOperandExtensionData")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
                   << "setting extension data for operand " << index << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];

    if (!isExtension(operand.type)) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
                   << "setting extension data for a base operand type "
                   << static_cast<int32_t>(operand.type);
        return ANEURALNETWORKS_BAD_DATA;
    }

    NN_VALIDATE_NULL_OR_SIZED("setOperandExtensionData", data, length);
    if (data == nullptr) {
        operand.extraParams = {};
    } else {
        operand.extraParams = Operand::ExtensionParams(
                std::vector<uint8_t>(reinterpret_cast<const uint8_t*>(data),
                                     reinterpret_cast<const uint8_t*>(data) + length));
    }
    return ANEURALNETWORKS_NO_ERROR;
}

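// Packs all large constant operand values into one shared memory pool: a first pass
// computes an aligned offset for each value and the total pool size, then the pool is
// allocated and each value is copied into place.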
int ModelBuilder::copyLargeValuesToSharedMemory() {
    VLOG(MODEL) << __func__ << " has " << mLargeOperandValues.size() << " values.";
    if (!mLargeOperandValues.empty()) {
        // Calculate the size of the shared memory needed for all the large values.
        // Also sets the offset for each value within the memory.
        size_t poolSize = 0;
        for (LargeValue& l : mLargeOperandValues) {
            Operand& operand = mOperands[l.operandIndex];
            CHECK_EQ(operand.lifetime, Operand::LifeTime::CONSTANT_REFERENCE);
            poolSize += alignBytesNeeded(poolSize, operand.location.length);
            operand.location.offset = poolSize;
            poolSize += operand.location.length;
        }

        // Allocate the shared memory.
        int n;
        std::tie(n, mLargeValueMemory) = MemoryAshmem::create(poolSize);
        NN_RETURN_IF_ERROR(n);
        uint8_t* memoryPointer = mLargeValueMemory->getPointer();
        uint32_t poolIndex = mMemories.add(mLargeValueMemory.get());
        VLOG(MODEL) << "Allocated large value pool of size " << poolSize << " at index "
                    << poolIndex;

        // Copy the values to this memory.
        for (LargeValue& l : mLargeOperandValues) {
            Operand& operand = mOperands[l.operandIndex];
            operand.location.poolIndex = poolIndex;
            memcpy(memoryPointer + operand.location.offset, l.buffer, operand.location.length);
        }
    }
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::setOperandValueFromMemory(uint32_t index, const RuntimeMemory* memory,
                                            uint32_t offset, size_t length) {
    VLOG(MODEL) << __func__ << " for operand " << index << " offset " << offset << " size "
                << length;
    if (badState("setOperandValueFromMemory")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    if (index >= operandCount()) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
                   << " of " << operandCount();
        return ANEURALNETWORKS_BAD_DATA;
    }
    Operand& operand = mOperands[index];
    if (TypeManager::get()->isTensorType(operand.type) && tensorHasUnspecifiedDimensions(operand)) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
                   << " which has operand type that is not fully specified";
        return ANEURALNETWORKS_BAD_DATA;
    }
    uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
    if (neededLength != length) {
        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
                   << " bytes when needing " << neededLength;
        return ANEURALNETWORKS_BAD_DATA;
    }
    // Set compilation = nullptr to indicate that the memory is used for a model constant.
    // In this case, IOType::INPUT is a placeholder value that is ignored by the validator.
    if (!memory->getValidator().validate(/*compilation=*/nullptr, /*placeholder*/ IOType::INPUT,
                                         index, nullptr, offset, length)) {
        return ANEURALNETWORKS_BAD_DATA;
    }
    operand.lifetime = Operand::LifeTime::CONSTANT_REFERENCE;
    operand.location = {.poolIndex = mMemories.add(memory),
                        .offset = offset,
                        .length = static_cast<uint32_t>(length)};
    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::addOperation(ANeuralNetworksOperationType type, uint32_t inputCount,
                               const uint32_t* inputs, uint32_t outputCount,
                               const uint32_t* outputs) {
    if (badState("addOperation")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    OperationType operationType = static_cast<OperationType>(type);
    if (isExtension(operationType) && !TypeManager::get()->areExtensionsAllowed()) {
        LOG(ERROR) << "Extensions are not supported for this process.";
        return ANEURALNETWORKS_BAD_DATA;
    }
    if (operationType == OperationType::OEM_OPERATION && !mHasOEMOperation) {
        LOG(WARNING) << "OEM_OPERATION is deprecated. Use Extensions instead.";
    }

    if (!isExtension(operationType)) {
        if (!validCode(kNumberOfOperationTypes, kNumberOfOperationTypesOEM, type)) {
            LOG(ERROR) << "ANeuralNetworksModel_addOperation invalid operation type " << type;
            return ANEURALNETWORKS_BAD_DATA;
        }
    } else {
        const Extension* extension;
        uint16_t extensionPrefix = getExtensionPrefix(static_cast<uint32_t>(operationType));
        if (!TypeManager::get()->getExtensionInfo(extensionPrefix, &extension)) {
            LOG(ERROR) << "Extension operation type " << operationType << " is not recognized";
            return ANEURALNETWORKS_BAD_DATA;
        }
    }

    NN_VALIDATE_NULL_OR_SIZED("addOperation", inputs, inputCount);
    NN_VALIDATE_NULL_OR_SIZED("addOperation", outputs, outputCount);
    Operation operation = {
            .type = operationType,
            .inputs = makeVector(inputs, inputCount),
            .outputs = makeVector(outputs, outputCount),
    };
    if (auto result = validateOperationButNotOperands(operation, mOperands,
                                                      mReferencedSubgraphsForValidation);
        !result.ok()) {
        LOG(ERROR) << "Invalid Operation: " << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    uint32_t operationIndex = operationCount();
    if (operationIndex >= MAX_NUMBER_OF_OPERATIONS) {
        LOG(ERROR) << "ANeuralNetworksModel_addOperation exceeds max operations";
        return ANEURALNETWORKS_BAD_DATA;
    }

    mOperations.push_back(std::move(operation));
    mHasOEMOperation |= (operationType == OperationType::OEM_OPERATION);
    mHasExtensionOperation |= isExtension(operationType);

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::identifyInputsAndOutputs(uint32_t inputCount, const uint32_t* inputs,
                                           uint32_t outputCount, const uint32_t* outputs) {
    if (badState("identifyInputsAndOutputs")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    NN_VALIDATE_NULL_OR_SIZED("identifyInputsAndOutputs", inputs, inputCount);
    if (auto result = validateOperandList(makeVector(inputs, inputCount), operandCount(),
                                          "ANeuralNetworksModel_identifyInputsAndOutputs inputs");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }
    NN_VALIDATE_NULL_OR_SIZED("identifyInputsAndOutputs", outputs, outputCount);
    if (auto result = validateOperandList(makeVector(outputs, outputCount), operandCount(),
                                          "ANeuralNetworksModel_identifyInputsAndOutputs outputs");
        !result.ok()) {
        LOG(ERROR) << result.error();
        return ANEURALNETWORKS_BAD_DATA;
    }

    // Makes a copy of the index list, validates the arguments, and changes
    // the lifetime info of the corresponding operand.
    auto setArguments = [&](std::vector<uint32_t>* indexVector, uint32_t indexCount,
                            const uint32_t* indexList, Operand::LifeTime lifetime) -> bool {
        indexVector->resize(indexCount);
        for (uint32_t i = 0; i < indexCount; i++) {
            const uint32_t operandIndex = indexList[i];
            if (operandIndex >= mOperands.size()) {
                LOG(ERROR) << "ANeuralNetworksModel_identifyInputsAndOutputs Can't set input or "
                              "output to be "
                           << operandIndex << " as this exceeds the number of operands "
                           << mOperands.size();
                return false;
            }
            (*indexVector)[i] = operandIndex;
            Operand& operand = mOperands[operandIndex];
            if (operand.lifetime != Operand::LifeTime::TEMPORARY_VARIABLE) {
                LOG(ERROR) << "ANeuralNetworksModel_identifyInputsAndOutputs Can't set operand "
                           << operandIndex
                           << " to be an input or output.  Check that it's not a constant or "
                              "already an input or output";
                return false;
            }
            operand.lifetime = lifetime;
        }
        return true;
    };

    if (!setArguments(&mInputIndexes, inputCount, inputs, Operand::LifeTime::SUBGRAPH_INPUT) ||
        !setArguments(&mOutputIndexes, outputCount, outputs, Operand::LifeTime::SUBGRAPH_OUTPUT)) {
        return ANEURALNETWORKS_BAD_DATA;
    }

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::relaxComputationFloat32toFloat16(bool allow) {
    if (badState("relaxComputationFloat32toFloat16")) {
        return ANEURALNETWORKS_BAD_STATE;
    }

    mRelaxComputationFloat32toFloat16 = allow;

    return ANEURALNETWORKS_NO_ERROR;
}

int ModelBuilder::createCompilation(CompilationBuilder** compilation,
                                    const std::vector<std::shared_ptr<Device>>& devices,
                                    bool explicitDeviceList) {
    if (!mCompletedModel || mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksCompilation_create passed an unfinished or invalid model";
        *compilation = nullptr;
        return ANEURALNETWORKS_BAD_STATE;
    }
    *compilation = new (std::nothrow) CompilationBuilder(this, devices, explicitDeviceList);
    return (*compilation ? ANEURALNETWORKS_NO_ERROR : ANEURALNETWORKS_OUT_OF_MEMORY);
}

int ModelBuilder::finish() {
    if (mCompletedModel) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called more than once";
        return ANEURALNETWORKS_BAD_STATE;
    }
    if (mInvalidModel) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called on an invalid model";
        return ANEURALNETWORKS_BAD_STATE;
    }

    int n = copyLargeValuesToSharedMemory();
    if (n != ANEURALNETWORKS_NO_ERROR) {
        return n;
    }

    // We sort the operations so that they will be in the appropriate
    // order for a single-threaded, op-at-a-time execution.
    // TODO: we don't need this if we always run the partitioner.
    if (!sortIntoRunOrder()) {
        // We expect sortIntoRunOrder() to have logged an appropriate error message.
        mInvalidModel = true;
        return ANEURALNETWORKS_BAD_DATA;
    }

    // TODO: Modify validation so that it can be called without creating a Model.
    // NOTE: Must sortIntoRunOrder() before validation; validator expects operations
    //       to have been sorted.
    // NOTE: Must copyLargeValuesToSharedMemory() before validation; otherwise,
    //       a CONSTANT_REFERENCE operand will not have correct .poolIndex, and
    //       validation will not work properly.
    const Model modelForValidation = makeModel();
    if (auto result = validate(modelForValidation); !result.ok()) {
        LOG(ERROR) << "ANeuralNetworksModel_finish called on invalid model: " << result.error();
        mInvalidModel = true;
        return ANEURALNETWORKS_BAD_DATA;
    }
    if (VLOG_IS_ON(MODEL)) {
        graphDump("ModelBuilder::finish", modelForValidation, nullptr);
    }

    removeTrailingArgumentsWithDefaultValues();

    mCompletedModel = true;
    return ANEURALNETWORKS_NO_ERROR;
}

static void logRemoval(const Operation& operation, uint32_t count,
                       const std::vector<Operand>& operands) {
    std::ostringstream message;
    message << "Operation " << operation.type << " with inputs {";
    for (uint32_t i = 0; i < operation.inputs.size(); ++i) {
        if (i != 0) {
            message << ", ";
        }
        message << operands[operation.inputs[i]].type;
    }
    message << "} has trailing optional inputs set to default values. Removing " << count
            << " trailing inputs.";
    VLOG(MODEL) << message.str();
}

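// Removes trailing optional inputs that are set to their default values (as reported by
// getNumTrailingArgumentsToRemove) from every operation of the model.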
void ModelBuilder::removeTrailingArgumentsWithDefaultValues() {
    for (Operation& operation : mOperations) {
        const uint32_t count = getNumTrailingArgumentsToRemove(operation);
        if (count == 0) {
            continue;
        }
        if (VLOG_IS_ON(MODEL)) {
            logRemoval(operation, count, mOperands);
        }
        const uint32_t inputCount = operation.inputs.size();
        CHECK_LT(count, inputCount);
        const uint32_t newInputCount = inputCount - count;
        operation.inputs.resize(newInputCount);
    }
}

// See countMatchingTrailingArguments().
enum class TailSpec {
    BOOL_FALSE,
    INT32_ONE,
    INT32_NEGATIVE_ONE,
};

// See countMatchingTrailingArguments().
static bool matchesSpec(TailSpec spec, const Operand& operand,
                        const std::vector<uint8_t>& mSmallOperandValues) {
    const void* valuePtr = nullptr;
    if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY) {
        valuePtr = static_cast<const void*>(&mSmallOperandValues[operand.location.offset]);
    } else if (operand.lifetime == Operand::LifeTime::POINTER) {
        valuePtr = std::get<const void*>(operand.location.pointer);
    } else {
        // CONSTANT_REFERENCE operands are not supported to avoid mapping memory
        // during compilation.
        return false;
    }
    switch (spec) {
        case TailSpec::BOOL_FALSE:
            return operand.type == OperandType::BOOL &&
                   *static_cast<const bool8*>(valuePtr) == false;
        case TailSpec::INT32_ONE:
            return operand.type == OperandType::INT32 &&
                   *static_cast<const int32_t*>(valuePtr) == 1;
        case TailSpec::INT32_NEGATIVE_ONE:
            return operand.type == OperandType::INT32 &&
                   *static_cast<const int32_t*>(valuePtr) == -1;
        default:
            CHECK(false) << "Unhandled TailSpec: " << static_cast<int>(spec);
    }
}

// Returns the number of trailing operation inputs that match the specification.
//
// Example:
//     operation.inputs = {BOOL_TRUE, BOOL_TRUE,  INT32_ONE, INT32_NEGATIVE_ONE}
//     tail             =            {BOOL_FALSE, INT32_ONE, INT32_NEGATIVE_ONE}
//     tailStartIndex   = 1    matching elements: ^^^^^^^^^  ^^^^^^^^^^^^^^^^^^
static uint32_t countMatchingTrailingArguments(uint32_t tailStartIndex,
                                               const std::vector<TailSpec>& tail,
                                               const Operation& operation,
                                               const std::vector<Operand>& operands,
                                               const std::vector<uint8_t>& smallOperandValues) {
    const uint32_t inputCount = operation.inputs.size();
    uint32_t count = 0;
    for (uint32_t i = inputCount - 1; i >= tailStartIndex; --i) {
        const Operand& operand = operands[operation.inputs[i]];
        if (!matchesSpec(tail[i - tailStartIndex], operand, smallOperandValues)) {
            break;
        }
        ++count;
    }
    return count;
}

uint32_t ModelBuilder::getNumTrailingArgumentsToRemove(const Operation& operation) const {
    const uint32_t inputCount = operation.inputs.size();
    auto getCount = [this, &operation](uint32_t tailStartIndex, const std::vector<TailSpec>& tail) {
        return countMatchingTrailingArguments(tailStartIndex, tail, operation, mOperands,
                                              mSmallOperandValues);
    };
    using TS = TailSpec;
    // Check if the operation has optional arguments that might be set to default
    // values. Skip the counting if no optional arguments are present.
    switch (operation.type) {
        case OperationType::AVERAGE_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::CONV_2D: {
            if (10 < inputCount && inputCount <= 13 &&
                mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 13 inputs
                // API level 27: 10 inputs
                uint32_t count = getCount(10, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 11 and 12 must come together.
                return inputCount - count == 12 ? 0 : count;
            } else if (7 < inputCount && inputCount <= 10 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 10 inputs
                // API level 27: 7 inputs
                uint32_t count = getCount(7, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 8 and 9 must come together.
                return inputCount - count == 9 ? 0 : count;
            }
        } break;
        case OperationType::DEPTHWISE_CONV_2D: {
            if (11 < inputCount && inputCount <= 14 &&
                mOperands[operation.inputs[8]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 11 to 14 inputs
                // API level 27: 11 inputs
                uint32_t count = getCount(11, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 12 and 13 must come together.
                return inputCount - count == 13 ? 0 : count;
            } else if (8 < inputCount && inputCount <= 11 &&
                       mOperands[operation.inputs[8]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 8 to 11 inputs
                // API level 27: 8 inputs
                uint32_t count = getCount(8, {TS::BOOL_FALSE, TS::INT32_ONE, TS::INT32_ONE});
                // Inputs 9 and 10 must come together.
                return inputCount - count == 10 ? 0 : count;
            }
        } break;
        case OperationType::DEPTH_TO_SPACE: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::L2_NORMALIZATION: {
            if (inputCount == 2) {
                // API level 29: 1 to 2 inputs
                // API level 27: 1 input
                return getCount(1, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::L2_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
            if (inputCount == 6) {
                // API level 29: 5 to 6 inputs
                // API level 27: 5 inputs
                return getCount(5, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::MAX_POOL_2D: {
            if (inputCount == 11 && mOperands[operation.inputs[7]].type == OperandType::INT32) {
                // Explicit padding
                // API level 29: 10 to 11 inputs
                // API level 27: 10 inputs
                return getCount(10, {TS::BOOL_FALSE});
            } else if (inputCount == 8 &&
                       mOperands[operation.inputs[7]].type == OperandType::BOOL) {
                // Implicit padding
                // API level 29: 7 to 8 inputs
                // API level 27: 7 inputs
                return getCount(7, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::RESIZE_BILINEAR: {
            if (3 < inputCount && inputCount <= 6) {
                // By shape:
                //     API level 30: 3 to 6 inputs
                //     API level 29: 3 to 4 inputs
                //     API level 27: 3 inputs
                // By scale:
                //     API level 30: 3 to 6 inputs
                //     API level 29: 3 to 4 inputs
                return getCount(3, {TS::BOOL_FALSE, TS::BOOL_FALSE, TS::BOOL_FALSE});
            }
        } break;
        case OperationType::SOFTMAX: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::INT32_NEGATIVE_ONE});
            }
        } break;
        case OperationType::SPACE_TO_DEPTH: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 27: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::BATCH_TO_SPACE_ND: {
            if (inputCount == 3) {
                // API level 29: 2 to 3 inputs
                // API level 28: 2 inputs
                return getCount(2, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::SPACE_TO_BATCH_ND: {
            if (inputCount == 4) {
                // API level 29: 3 to 4 inputs
                // API level 28: 3 inputs
                return getCount(3, {TS::BOOL_FALSE});
            }
        } break;
        case OperationType::RESIZE_NEAREST_NEIGHBOR: {
            if (4 < inputCount && inputCount <= 6) {
                // By shape or scale
                // API level 30: 4 to 6 inputs
                // API level 29: 4 inputs
                return getCount(4, {TS::BOOL_FALSE, TS::BOOL_FALSE});
            }
        } break;
        default: {
            // Do nothing.
        } break;
    }
    // No trailing optional arguments to check.
    return 0;
}

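// Topologically sorts mOperations so that every operation appears after the operations
// that produce its inputs, the order needed for single-threaded, one-op-at-a-time
// execution. Returns false if the graph has a cycle or a never-written operand.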
bool ModelBuilder::sortIntoRunOrder() {
    // Note that this may be called before the model has been
    // validated, so we must code defensively.  However, we can assume
    // an Operation's inputs and outputs have legal indices -- this
    // should have been checked in addOperation().

    if (!mSortedOperationIndexMap.empty()) {
        LOG(ERROR) << "Operations were already sorted into run order.";
        return true;
    }

    // Tracks the operations that can be executed.
    std::vector<uint32_t> sortedOperationIndexMap;
    std::vector<uint32_t> opsReadyToRun;
    std::vector<Operation> runOrder;

    // Tracks how many inputs are needed for each operation to be ready to run.
    std::multimap<uint32_t, uint32_t> operandToOperations;
    std::vector<uint32_t> unknownInputCount(operationCount());
    for (uint32_t operationIndex = 0; operationIndex < operationCount(); operationIndex++) {
        uint32_t& count = unknownInputCount[operationIndex];
        count = 0;
        for (uint32_t operandIndex : mOperations[operationIndex].inputs) {
            auto lifetime = mOperands[operandIndex].lifetime;
            if (lifetime == Operand::LifeTime::TEMPORARY_VARIABLE ||
                lifetime == Operand::LifeTime::SUBGRAPH_OUTPUT) {
                count++;
                operandToOperations.insert(
                        std::pair<uint32_t, uint32_t>(operandIndex, operationIndex));
            }
        }
        if (count == 0) {
            opsReadyToRun.push_back(operationIndex);
        }
    }

    while (opsReadyToRun.size() > 0) {
        // Execute the next op
        int opIndex = opsReadyToRun.back();
        opsReadyToRun.pop_back();
        const Operation& operation = mOperations[opIndex];

        runOrder.push_back(mOperations[opIndex]);
        sortedOperationIndexMap.push_back(opIndex);

        // Mark all its outputs as known.
        for (uint32_t operandIndex : operation.outputs) {
            auto range = operandToOperations.equal_range(operandIndex);
            for (auto i = range.first; i != range.second; i++) {
                uint32_t& count = unknownInputCount[i->second];
                if (--count == 0) {
                    opsReadyToRun.push_back(i->second);
                }
            }
        }
    }

    if (runOrder.size() != mOperations.size()) {
        nnAssert(runOrder.size() < mOperations.size());
        // Graph must contain at least one cycle or one never-written
        // operand, because there is at least one Operation that never
        // became ready.
        LOG(ERROR) << "Graph contains at least one cycle or one never-written operand";
        return false;
    }

    mSortedOperationIndexMap = std::move(sortedOperationIndexMap);
    mOperations = std::move(runOrder);
    return true;
}

// A helper class to simplify state management when creating a Model.
class ModelBuilder::ModelMaker {
   public:
    static Model run(const ModelBuilder* model);

   private:
    static Model::Subgraph makeSubgraph(const ModelBuilder* model);
    ModelMaker() {}
    Model makeModel(const ModelBuilder* mainModel);
    uint32_t addSubgraph(const ModelBuilder* refModel);
    void updateOperandLocations(const ModelBuilder* refModel, Model::Subgraph* subgraph);
    void addExtensions(const ModelBuilder* model);
    void addExtensionWithPrefix(uint16_t prefix);

    std::vector<Model::Subgraph> mRefSubgraphs;
    Model::OperandValues mOperandValues;
    MemoryTracker mMemories;
    std::vector<Model::ExtensionNameAndPrefix> mExtensionNameToPrefix;
    std::set<uint16_t> mPrefixSet;
};

Model ModelBuilder::makeModel() const {
    // TODO: Cache the Model to speed up subsequent calls.
    return ModelMaker::run(this);
}

Model ModelBuilder::ModelMaker::run(const ModelBuilder* model) {
    // run() ensures the state of ModelMaker is destroyed after the call.
    return ModelMaker().makeModel(model);
}

Model ModelBuilder::ModelMaker::makeModel(const ModelBuilder* mainModel) {
    addExtensions(mainModel);
    Model model;
    model.main = makeSubgraph(mainModel);
    updateOperandLocations(mainModel, &model.main);
    model.referenced = std::move(mRefSubgraphs);
    model.operandValues = std::move(mOperandValues);
    model.pools.reserve(mMemories.size());
    std::transform(mMemories.begin(), mMemories.end(), std::back_inserter(model.pools),
                   [](const RuntimeMemory* m) { return m->getMemory(); });
    model.relaxComputationFloat32toFloat16 = mainModel->mRelaxComputationFloat32toFloat16;
    model.extensionNameToPrefix = std::move(mExtensionNameToPrefix);
    return model;
}

Model::Subgraph ModelBuilder::ModelMaker::makeSubgraph(const ModelBuilder* model) {
    Model::Subgraph subgraph;
    subgraph.operands = model->mOperands;
    subgraph.operations = model->mOperations;
    subgraph.inputIndexes = model->mInputIndexes;
    subgraph.outputIndexes = model->mOutputIndexes;
    return subgraph;
}

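// Rewrites CONSTANT_COPY and CONSTANT_REFERENCE operand locations so that they point into
// the merged Model::OperandValues and memory pools, then recursively converts any
// referenced models into subgraphs.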
void ModelBuilder::ModelMaker::updateOperandLocations(const ModelBuilder* refModel,
                                                      Model::Subgraph* subgraph) {
    for (Operand& operand : subgraph->operands) {
        if (operand.lifetime == Operand::LifeTime::CONSTANT_COPY) {
            uint32_t valueLength = operand.location.length;
            uint32_t originalOffset = operand.location.offset;
            operand.location = mOperandValues.append(&refModel->mSmallOperandValues[originalOffset],
                                                     valueLength);
        } else if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE) {
            uint32_t originalPoolIndex = operand.location.poolIndex;
            operand.location.poolIndex = mMemories.add(refModel->mMemories[originalPoolIndex]);
        }
    }
    // Do recursive calls at the end to improve locality of mOperandValues.
    for (Operand& operand : subgraph->operands) {
        if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
            uint32_t refModelIndex = operand.location.offset;
            // TODO(b/147875885): Avoid creating duplicate refSubgraphs when
            // a single refModel is referenced multiple times.
            operand.location.offset = addSubgraph(refModel->mReferencedModels[refModelIndex]);
        }
    }
}

uint32_t ModelBuilder::ModelMaker::addSubgraph(const ModelBuilder* refModel) {
    uint32_t index = mRefSubgraphs.size();
    mRefSubgraphs.push_back(makeSubgraph(refModel));
    updateOperandLocations(refModel, &mRefSubgraphs.back());
    return index;
}

void ModelBuilder::ModelMaker::addExtensions(const ModelBuilder* model) {
    for (const auto& operand : model->mOperands) {
        if (isExtension(operand.type)) {
            addExtensionWithPrefix(static_cast<uint32_t>(operand.type) >> kExtensionTypeBits);
        }
    }
    for (const auto& operation : model->mOperations) {
        if (isExtension(operation.type)) {
            addExtensionWithPrefix(static_cast<uint32_t>(operation.type) >> kExtensionTypeBits);
        }
    }
    for (const auto& refModel : model->mReferencedModels) {
        addExtensions(refModel);
    }
}

void ModelBuilder::ModelMaker::addExtensionWithPrefix(uint16_t prefix) {
    if (!mPrefixSet.insert(prefix).second) {
        return;
    }
    const Extension* extension;
    CHECK(TypeManager::get()->getExtensionInfo(prefix, &extension));
    mExtensionNameToPrefix.push_back({
            .name = extension->name,
            .prefix = prefix,
    });
}

#undef NN_VALIDATE_NULL_OR_SIZED

}  // namespace nn
}  // namespace android