• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
19 
#include <functional>
#include <utility>

#include "OperationsUtils.h"
23 
24 namespace android {
25 namespace nn {
26 
27 // Encapsulates an operation implementation.
28 struct OperationRegistration {
29     OperationType type;
30     const char* name;
31 
32     // Validates operand types, shapes, and any values known during graph creation.
33     std::function<Result<Version>(const IOperationValidationContext*)> validate;
34 
35     // prepare is called when the inputs this operation depends on have been
36     // computed. Typically, prepare does any remaining validation and sets
37     // output shapes via context->setOutputShape(...).
38     std::function<bool(IOperationExecutionContext*)> prepare;
39 
40     // Executes the operation, reading from context->getInputBuffer(...)
41     // and writing to context->getOutputBuffer(...).
42     std::function<bool(IOperationExecutionContext*)> execute;
43 
44     struct Flag {
45         // Whether the operation allows at least one operand to be omitted.
46         bool allowOmittedOperand = false;
47         // Whether the operation allows at least one input operand to be a zero-sized tensor.
48         bool allowZeroSizedInput = false;
49     } flags;
50 
OperationRegistrationOperationRegistration51     OperationRegistration(
52             OperationType type, const char* name,
53             std::function<Result<Version>(const IOperationValidationContext*)> validate,
54             std::function<bool(IOperationExecutionContext*)> prepare,
55             std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
56         : type(type),
57           name(name),
58           validate(std::move(validate)),
59           prepare(std::move(prepare)),
60           execute(std::move(execute)),
61           flags(flags) {}
62 };
63 
64 // A registry of operation implementations.
65 class IOperationResolver {
66    public:
67     virtual const OperationRegistration* findOperation(OperationType operationType) const = 0;
~IOperationResolver()68     virtual ~IOperationResolver() {}
69 };
70 
71 // A registry of builtin operation implementations.
72 //
73 // Note that some operations bypass BuiltinOperationResolver (b/124041202).
74 //
75 // Usage:
76 //   const OperationRegistration* operationRegistration =
77 //           BuiltinOperationResolver::get()->findOperation(operationType);
78 //   NN_RET_CHECK(operationRegistration != nullptr);
79 //   NN_RET_CHECK(operationRegistration->validate != nullptr);
80 //   NN_RET_CHECK(operationRegistration->validate(&context));
81 //
82 class BuiltinOperationResolver : public IOperationResolver {
83     DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver);
84 
85    public:
get()86     static const BuiltinOperationResolver* get() {
87         static BuiltinOperationResolver instance;
88         return &instance;
89     }
90 
91     const OperationRegistration* findOperation(OperationType operationType) const override;
92 
93     // The number of operation types (OperationCode) defined in NeuralNetworks.h.
94     static constexpr int kNumberOfOperationTypes = 102;
95 
96    private:
97     BuiltinOperationResolver();
98 
99     void registerOperation(const OperationRegistration* operationRegistration);
100 
101     const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
102 };
103 
// NN_REGISTER_OPERATION(identifier, ...) defines a function
// `register_<identifier>()` that returns a pointer to a function-local static
// OperationRegistration built from the macro arguments. The registration is
// meant for consumption by an operation resolver (presumably collected by
// BuiltinOperationResolver through registerOperation(); confirm in the
// resolver implementation).
//
// `identifier` must name an enumerator of OperationType, because the macro
// expands it to OperationType::identifier.
//
// Any arguments after `execute` are forwarded verbatim into a braced
// initializer for OperationRegistration::Flag. Check that struct for the
// available fields and their default values. Passing no extra arguments
// yields `{}`, i.e. all flags at their defaults.
//
// Usage:
//
// - With default flags.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute);
//
// - With a customized flag.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
//
// - With multiple customized flags.
//   NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                         foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
//                         .allowZeroSizedInput = true);
//
#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
// Full version: the registration keeps validate, prepare and execute.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...)     \
    const OperationRegistration* register_##identifier() {                                    \
        static OperationRegistration registration(OperationType::identifier, operationName,   \
                                                  validate, prepare, execute, {__VA_ARGS__}); \
        return &registration;                                                                 \
    }
#else
// Validation-only version: the prepare and execute arguments are never
// referenced, and the registration's prepare/execute members are built from
// nullptr (empty std::function objects). The intent is that the compiler
// then omits the CPU execution code, so only validation logic ends up in
// libneuralnetworks_utils.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \
                              ...)                                                                 \
    const OperationRegistration* register_##identifier() {                                         \
        static OperationRegistration registration(OperationType::identifier, operationName,        \
                                                  validate, nullptr, nullptr, {__VA_ARGS__});      \
        return &registration;                                                                      \
    }
#endif
142 
143 }  // namespace nn
144 }  // namespace android
145 
146 #endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
147