• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H
19 
#include <functional>
#include <utility>

#include "OperationsExecutionUtils.h"
#include "OperationsValidationUtils.h"
24 
25 namespace android {
26 namespace nn {
27 
28 // Encapsulates an operation implementation.
29 struct OperationRegistration {
30     OperationType type;
31     const char* name;
32 
33     // Validates operand types, shapes, and any values known during graph creation.
34     // TODO(b/213938830): operation validation dispatch is duplicated and does not handle extension
35     // types.
36     std::function<Result<Version>(const IOperationValidationContext*)> validate;
37 
38     // prepare is called when the inputs this operation depends on have been
39     // computed. Typically, prepare does any remaining validation and sets
40     // output shapes via context->setOutputShape(...).
41     std::function<bool(IOperationExecutionContext*)> prepare;
42 
43     // Executes the operation, reading from context->getInputBuffer(...)
44     // and writing to context->getOutputBuffer(...).
45     std::function<bool(IOperationExecutionContext*)> execute;
46 
47     struct Flag {
48         // Whether the operation allows at least one operand to be omitted.
49         bool allowOmittedOperand = false;
50         // Whether the operation allows at least one input operand to be a zero-sized tensor.
51         bool allowZeroSizedInput = false;
52     } flags;
53 
OperationRegistrationOperationRegistration54     OperationRegistration(
55             OperationType type, const char* name,
56             std::function<Result<Version>(const IOperationValidationContext*)> validate,
57             std::function<bool(IOperationExecutionContext*)> prepare,
58             std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
59         : type(type),
60           name(name),
61           validate(std::move(validate)),
62           prepare(std::move(prepare)),
63           execute(std::move(execute)),
64           flags(flags) {}
65 };
66 
67 // A registry of operation implementations.
68 class IOperationResolver {
69    public:
70     virtual const OperationRegistration* findOperation(OperationType operationType) const = 0;
~IOperationResolver()71     virtual ~IOperationResolver() {}
72 };
73 
74 // A registry of builtin operation implementations.
75 //
76 // Note that some operations bypass BuiltinOperationResolver (b/124041202).
77 //
78 // Usage:
79 //   const OperationRegistration* operationRegistration =
80 //           BuiltinOperationResolver::get()->findOperation(operationType);
81 //   NN_RET_CHECK(operationRegistration != nullptr);
82 //   NN_RET_CHECK(operationRegistration->validate != nullptr);
83 //   NN_RET_CHECK(operationRegistration->validate(&context));
84 //
85 class BuiltinOperationResolver : public IOperationResolver {
86     DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver);
87 
88    public:
get()89     static const BuiltinOperationResolver* get() {
90         static BuiltinOperationResolver instance;
91         return &instance;
92     }
93 
94     const OperationRegistration* findOperation(OperationType operationType) const override;
95 
96     // The number of operation types (OperationCode) defined in NeuralNetworksTypes.h.
97     static constexpr int kNumberOfOperationTypes = 106;
98 
99 #ifdef NN_EXPERIMENTAL_FEATURE
100     // The number of experimental operation types (ANeuralNetworksExperimentalOperationCode) defined
101     // in NeuralNetworksExperimentalFeatures.h.
102     static constexpr int kNumberOfExperimentalOperationTypes = 1;
103 
104     // The starting value of experimental operation types (ANeuralNetworksExperimentalOperationCode)
105     // defined in NeuralNetworksExperimentalFeatures.h.
106     static constexpr int kStartOfExperimentalOperations = 20000;
107 #endif  // NN_EXPERIMENTAL_FEATURE
108 
109    private:
110     BuiltinOperationResolver();
111 
112     void registerOperation(const OperationRegistration* operationRegistration);
113 
114     const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
115 
116 #ifdef NN_EXPERIMENTAL_FEATURE
117     const OperationRegistration* mExperimentalRegistrations[kNumberOfExperimentalOperationTypes] =
118             {};
119 #endif  // NN_EXPERIMENTAL_FEATURE
120 };
121 
// NN_REGISTER_OPERATION defines a function `register_<identifier>()` that returns a pointer to a
// function-local static OperationRegistration (built on the first call). The returned pointer is
// what BuiltinOperationResolver::registerOperation() accepts.
//
// Any arguments after `execute` are forwarded as designated initializers of
// OperationRegistration::Flag; see that struct for the available fields and their defaults.
//
// Examples:
//
// - Default flags:
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute);
//
// - One customized flag:
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
//
// - Several customized flags:
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
//                         .allowZeroSizedInput = true);
//
#ifndef NN_INCLUDE_CPU_IMPLEMENTATION
// Validation-only variant: the prepare/execute arguments are dropped and null callbacks are stored
// instead, so the compiler can omit the CPU execution code and only validation logic ends up in
// libneuralnetworks_common*.
#define NN_REGISTER_OPERATION(opIdentifier, opName, validateFn, unusedPrepareFn, unusedExecuteFn, \
                              ...)                                                                \
    const OperationRegistration* register_##opIdentifier() {                                      \
        static OperationRegistration sRegistration(OperationType::opIdentifier, opName,           \
                                                   validateFn, nullptr, nullptr, {__VA_ARGS__});  \
        return &sRegistration;                                                                    \
    }
#else
// Full variant: stores the validate, prepare, and execute callbacks.
#define NN_REGISTER_OPERATION(opIdentifier, opName, validateFn, prepareFn, executeFn, ...)        \
    const OperationRegistration* register_##opIdentifier() {                                      \
        static OperationRegistration sRegistration(OperationType::opIdentifier, opName,           \
                                                   validateFn, prepareFn, executeFn,              \
                                                   {__VA_ARGS__});                                \
        return &sRegistration;                                                                    \
    }
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION
160 
// Shorthand for NN_REGISTER_OPERATION that derives two arguments from the identifier:
// - the operation name is the stringified identifier (e.g. "FOO_OP");
// - the validator is the function named by NN_VALIDATION_FUNCTION_NAME(identifier).
// It first expands NN_VALIDATION_FUNCTION_SIGNATURE(identifier), presumably a forward declaration
// of that validator. Trailing arguments are forwarded as OperationRegistration::Flag initializers.
#define NN_REGISTER_OPERATION_DEFAULT_VALIDATION(opIdentifier, prepareFn, executeFn, ...)          \
    NN_VALIDATION_FUNCTION_SIGNATURE(opIdentifier);                                              \
    NN_REGISTER_OPERATION(opIdentifier, #opIdentifier, NN_VALIDATION_FUNCTION_NAME(opIdentifier), \
                          prepareFn, executeFn, __VA_ARGS__);
165 
// Defines `register_<identifier>()` for an operation type that has no implementation; the function
// always returns nullptr instead of a registration.
#define NN_OPERATION_IS_NOT_IMPLEMENTED(opIdentifier)        \
    const OperationRegistration* register_##opIdentifier() { \
        return nullptr;                                      \
    }
168 
169 }  // namespace nn
170 }  // namespace android
171 
172 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_OPERATION_RESOLVER_H
173