1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H 18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H 19 20 #include <utility> 21 22 #include "OperationsExecutionUtils.h" 23 #include "OperationsValidationUtils.h" 24 25 namespace android { 26 namespace nn { 27 28 // Encapsulates an operation implementation. 29 struct OperationRegistration { 30 OperationType type; 31 const char* name; 32 33 // Validates operand types, shapes, and any values known during graph creation. 34 // TODO(b/213938830): operation validation dispatch is duplicated and does not handle extension 35 // types. 36 std::function<Result<Version>(const IOperationValidationContext*)> validate; 37 38 // prepare is called when the inputs this operation depends on have been 39 // computed. Typically, prepare does any remaining validation and sets 40 // output shapes via context->setOutputShape(...). 41 std::function<bool(IOperationExecutionContext*)> prepare; 42 43 // Executes the operation, reading from context->getInputBuffer(...) 44 // and writing to context->getOutputBuffer(...). 45 std::function<bool(IOperationExecutionContext*)> execute; 46 47 struct Flag { 48 // Whether the operation allows at least one operand to be omitted. 49 bool allowOmittedOperand = false; 50 // Whether the operation allows at least one input operand to be a zero-sized tensor. 51 bool allowZeroSizedInput = false; 52 } flags; 53 OperationRegistrationOperationRegistration54 OperationRegistration( 55 OperationType type, const char* name, 56 std::function<Result<Version>(const IOperationValidationContext*)> validate, 57 std::function<bool(IOperationExecutionContext*)> prepare, 58 std::function<bool(IOperationExecutionContext*)> execute, Flag flags) 59 : type(type), 60 name(name), 61 validate(std::move(validate)), 62 prepare(std::move(prepare)), 63 execute(std::move(execute)), 64 flags(flags) {} 65 }; 66 67 // A registry of operation implementations. 68 class IOperationResolver { 69 public: 70 virtual const OperationRegistration* findOperation(OperationType operationType) const = 0; ~IOperationResolver()71 virtual ~IOperationResolver() {} 72 }; 73 74 // A registry of builtin operation implementations. 75 // 76 // Note that some operations bypass BuiltinOperationResolver (b/124041202). 77 // 78 // Usage: 79 // const OperationRegistration* operationRegistration = 80 // BuiltinOperationResolver::get()->findOperation(operationType); 81 // NN_RET_CHECK(operationRegistration != nullptr); 82 // NN_RET_CHECK(operationRegistration->validate != nullptr); 83 // NN_RET_CHECK(operationRegistration->validate(&context)); 84 // 85 class BuiltinOperationResolver : public IOperationResolver { 86 DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver); 87 88 public: get()89 static const BuiltinOperationResolver* get() { 90 static BuiltinOperationResolver instance; 91 return &instance; 92 } 93 94 const OperationRegistration* findOperation(OperationType operationType) const override; 95 96 // The number of operation types (OperationCode) defined in NeuralNetworksTypes.h. 97 static constexpr int kNumberOfOperationTypes = 106; 98 99 #ifdef NN_EXPERIMENTAL_FEATURE 100 // The number of experimental operation types (ANeuralNetworksExperimentalOperationCode) defined 101 // in NeuralNetworksExperimentalFeatures.h. 102 static constexpr int kNumberOfExperimentalOperationTypes = 1; 103 104 // The starting value of experimental operation types (ANeuralNetworksExperimentalOperationCode) 105 // defined in NeuralNetworksExperimentalFeatures.h. 106 static constexpr int kStartOfExperimentalOperations = 20000; 107 #endif // NN_EXPERIMENTAL_FEATURE 108 109 private: 110 BuiltinOperationResolver(); 111 112 void registerOperation(const OperationRegistration* operationRegistration); 113 114 const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {}; 115 116 #ifdef NN_EXPERIMENTAL_FEATURE 117 const OperationRegistration* mExperimentalRegistrations[kNumberOfExperimentalOperationTypes] = 118 {}; 119 #endif // NN_EXPERIMENTAL_FEATURE 120 }; 121 122 // NN_REGISTER_OPERATION creates OperationRegistration for consumption by 123 // OperationResolver. 124 // 125 // Usage: 126 // (check OperationRegistration::Flag for available fields and default values.) 127 // 128 // - With default flags. 129 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 130 // foo_op::prepare, foo_op::execute); 131 // 132 // - With a customized flag. 133 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 134 // foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true); 135 // 136 // - With multiple customized flags. 137 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 138 // foo_op::prepare, foo_op::execute, .allowOmittedOperand = true, 139 // .allowZeroSizedInput = true); 140 // 141 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION 142 #define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...) \ 143 const OperationRegistration* register_##identifier() { \ 144 static OperationRegistration registration(OperationType::identifier, operationName, \ 145 validate, prepare, execute, {__VA_ARGS__}); \ 146 return ®istration; \ 147 } 148 #else 149 // This version ignores CPU execution logic (prepare and execute). 150 // The compiler is supposed to omit that code so that only validation logic 151 // makes it into libneuralnetworks_common*. 152 #define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \ 153 ...) \ 154 const OperationRegistration* register_##identifier() { \ 155 static OperationRegistration registration(OperationType::identifier, operationName, \ 156 validate, nullptr, nullptr, {__VA_ARGS__}); \ 157 return ®istration; \ 158 } 159 #endif 160 161 #define NN_REGISTER_OPERATION_DEFAULT_VALIDATION(identifier, prepare, execute, ...) \ 162 NN_VALIDATION_FUNCTION_SIGNATURE(identifier); \ 163 NN_REGISTER_OPERATION(identifier, #identifier, NN_VALIDATION_FUNCTION_NAME(identifier), \ 164 prepare, execute, __VA_ARGS__); 165 166 #define NN_OPERATION_IS_NOT_IMPLEMENTED(identifier) \ 167 const OperationRegistration* register_##identifier() { return nullptr; } 168 169 } // namespace nn 170 } // namespace android 171 172 #endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H 173