1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H 18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H 19 20 #include <utility> 21 22 #include "OperationsUtils.h" 23 24 namespace android { 25 namespace nn { 26 27 // Encapsulates an operation implementation. 28 struct OperationRegistration { 29 OperationType type; 30 const char* name; 31 32 // Validates operand types, shapes, and any values known during graph creation. 33 std::function<Result<Version>(const IOperationValidationContext*)> validate; 34 35 // prepare is called when the inputs this operation depends on have been 36 // computed. Typically, prepare does any remaining validation and sets 37 // output shapes via context->setOutputShape(...). 38 std::function<bool(IOperationExecutionContext*)> prepare; 39 40 // Executes the operation, reading from context->getInputBuffer(...) 41 // and writing to context->getOutputBuffer(...). 42 std::function<bool(IOperationExecutionContext*)> execute; 43 44 struct Flag { 45 // Whether the operation allows at least one operand to be omitted. 46 bool allowOmittedOperand = false; 47 // Whether the operation allows at least one input operand to be a zero-sized tensor. 48 bool allowZeroSizedInput = false; 49 } flags; 50 OperationRegistrationOperationRegistration51 OperationRegistration( 52 OperationType type, const char* name, 53 std::function<Result<Version>(const IOperationValidationContext*)> validate, 54 std::function<bool(IOperationExecutionContext*)> prepare, 55 std::function<bool(IOperationExecutionContext*)> execute, Flag flags) 56 : type(type), 57 name(name), 58 validate(std::move(validate)), 59 prepare(std::move(prepare)), 60 execute(std::move(execute)), 61 flags(flags) {} 62 }; 63 64 // A registry of operation implementations. 65 class IOperationResolver { 66 public: 67 virtual const OperationRegistration* findOperation(OperationType operationType) const = 0; ~IOperationResolver()68 virtual ~IOperationResolver() {} 69 }; 70 71 // A registry of builtin operation implementations. 72 // 73 // Note that some operations bypass BuiltinOperationResolver (b/124041202). 74 // 75 // Usage: 76 // const OperationRegistration* operationRegistration = 77 // BuiltinOperationResolver::get()->findOperation(operationType); 78 // NN_RET_CHECK(operationRegistration != nullptr); 79 // NN_RET_CHECK(operationRegistration->validate != nullptr); 80 // NN_RET_CHECK(operationRegistration->validate(&context)); 81 // 82 class BuiltinOperationResolver : public IOperationResolver { 83 DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver); 84 85 public: get()86 static const BuiltinOperationResolver* get() { 87 static BuiltinOperationResolver instance; 88 return &instance; 89 } 90 91 const OperationRegistration* findOperation(OperationType operationType) const override; 92 93 // The number of operation types (OperationCode) defined in NeuralNetworks.h. 94 static constexpr int kNumberOfOperationTypes = 102; 95 96 private: 97 BuiltinOperationResolver(); 98 99 void registerOperation(const OperationRegistration* operationRegistration); 100 101 const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {}; 102 }; 103 104 // NN_REGISTER_OPERATION creates OperationRegistration for consumption by 105 // OperationResolver. 106 // 107 // Usage: 108 // (check OperationRegistration::Flag for available fields and default values.) 109 // 110 // - With default flags. 111 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 112 // foo_op::prepare, foo_op::execute); 113 // 114 // - With a customized flag. 115 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 116 // foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true); 117 // 118 // - With multiple customized flags. 119 // NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate, 120 // foo_op::prepare, foo_op::execute, .allowOmittedOperand = true, 121 // .allowZeroSizedInput = true); 122 // 123 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION 124 #define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...) \ 125 const OperationRegistration* register_##identifier() { \ 126 static OperationRegistration registration(OperationType::identifier, operationName, \ 127 validate, prepare, execute, {__VA_ARGS__}); \ 128 return ®istration; \ 129 } 130 #else 131 // This version ignores CPU execution logic (prepare and execute). 132 // The compiler is supposed to omit that code so that only validation logic 133 // makes it into libneuralnetworks_utils. 134 #define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \ 135 ...) \ 136 const OperationRegistration* register_##identifier() { \ 137 static OperationRegistration registration(OperationType::identifier, operationName, \ 138 validate, nullptr, nullptr, {__VA_ARGS__}); \ 139 return ®istration; \ 140 } 141 #endif 142 143 } // namespace nn 144 } // namespace android 145 146 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H 147