1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "OperationResolver" 18 19 #include "OperationResolver.h" 20 21 #include "NeuralNetworks.h" 22 23 namespace android { 24 namespace nn { 25 26 // TODO(b/119608412): Find a way to not reference every operation here. 27 const OperationRegistration* register_ABS(); 28 const OperationRegistration* register_ADD(); 29 const OperationRegistration* register_AVERAGE_POOL_2D(); 30 const OperationRegistration* register_AXIS_ALIGNED_BBOX_TRANSFORM(); 31 const OperationRegistration* register_BIDIRECTIONAL_SEQUENCE_RNN(); 32 const OperationRegistration* register_BOX_WITH_NMS_LIMIT(); 33 const OperationRegistration* register_CHANNEL_SHUFFLE(); 34 const OperationRegistration* register_CONCATENATION(); 35 const OperationRegistration* register_CONV_2D(); 36 const OperationRegistration* register_DEPTHWISE_CONV_2D(); 37 const OperationRegistration* register_DEQUANTIZE(); 38 const OperationRegistration* register_DETECTION_POSTPROCESSING(); 39 const OperationRegistration* register_DIV(); 40 const OperationRegistration* register_ELU(); 41 const OperationRegistration* register_EQUAL(); 42 const OperationRegistration* register_EXP(); 43 const OperationRegistration* register_FILL(); 44 const OperationRegistration* register_FLOOR(); 45 const OperationRegistration* register_FULLY_CONNECTED(); 46 const OperationRegistration* register_GATHER(); 47 const OperationRegistration* register_GENERATE_PROPOSALS(); 48 const OperationRegistration* register_GREATER(); 49 const OperationRegistration* register_GREATER_EQUAL(); 50 const OperationRegistration* register_HARD_SWISH(); 51 const OperationRegistration* register_HEATMAP_MAX_KEYPOINT(); 52 const OperationRegistration* register_INSTANCE_NORMALIZATION(); 53 const OperationRegistration* register_L2_NORMALIZATION(); 54 const OperationRegistration* register_L2_POOL_2D(); 55 const OperationRegistration* register_LESS(); 56 const OperationRegistration* register_LESS_EQUAL(); 57 const OperationRegistration* register_LOCAL_RESPONSE_NORMALIZATION(); 58 const OperationRegistration* register_LOG(); 59 const OperationRegistration* register_LOGICAL_AND(); 60 const OperationRegistration* register_LOGICAL_NOT(); 61 const OperationRegistration* register_LOGICAL_OR(); 62 const OperationRegistration* register_LOGISTIC(); 63 const OperationRegistration* register_LOG_SOFTMAX(); 64 const OperationRegistration* register_MAX_POOL_2D(); 65 const OperationRegistration* register_MUL(); 66 const OperationRegistration* register_NEG(); 67 const OperationRegistration* register_NOT_EQUAL(); 68 const OperationRegistration* register_PRELU(); 69 const OperationRegistration* register_QUANTIZE(); 70 const OperationRegistration* register_QUANTIZED_LSTM(); 71 const OperationRegistration* register_RANK(); 72 const OperationRegistration* register_REDUCE_ALL(); 73 const OperationRegistration* register_REDUCE_ANY(); 74 const OperationRegistration* register_REDUCE_MAX(); 75 const OperationRegistration* register_REDUCE_MIN(); 76 const OperationRegistration* register_REDUCE_PROD(); 77 const OperationRegistration* register_REDUCE_SUM(); 78 const OperationRegistration* register_RELU(); 79 const OperationRegistration* register_RELU1(); 80 const OperationRegistration* register_RELU6(); 81 const OperationRegistration* register_RESIZE_BILINEAR(); 82 const OperationRegistration* register_RESIZE_NEAREST_NEIGHBOR(); 83 const OperationRegistration* register_ROI_ALIGN(); 84 const OperationRegistration* register_ROI_POOLING(); 85 const OperationRegistration* register_RSQRT(); 86 const OperationRegistration* register_SELECT(); 87 const OperationRegistration* register_SIN(); 88 const OperationRegistration* register_SLICE(); 89 const OperationRegistration* register_SOFTMAX(); 90 const OperationRegistration* register_SQRT(); 91 const OperationRegistration* register_SQUEEZE(); 92 const OperationRegistration* register_STRIDED_SLICE(); 93 const OperationRegistration* register_SUB(); 94 const OperationRegistration* register_TANH(); 95 const OperationRegistration* register_TOPK_V2(); 96 const OperationRegistration* register_TRANSPOSE(); 97 const OperationRegistration* register_TRANSPOSE_CONV_2D(); 98 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_LSTM(); 99 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_RNN(); 100 BuiltinOperationResolver()101BuiltinOperationResolver::BuiltinOperationResolver() { 102 registerOperation(register_ABS()); 103 registerOperation(register_ADD()); 104 registerOperation(register_AVERAGE_POOL_2D()); 105 registerOperation(register_AXIS_ALIGNED_BBOX_TRANSFORM()); 106 registerOperation(register_BIDIRECTIONAL_SEQUENCE_RNN()); 107 registerOperation(register_BOX_WITH_NMS_LIMIT()); 108 registerOperation(register_CHANNEL_SHUFFLE()); 109 registerOperation(register_CONCATENATION()); 110 registerOperation(register_CONV_2D()); 111 registerOperation(register_DEPTHWISE_CONV_2D()); 112 registerOperation(register_DEQUANTIZE()); 113 registerOperation(register_DETECTION_POSTPROCESSING()); 114 registerOperation(register_DIV()); 115 registerOperation(register_ELU()); 116 registerOperation(register_EQUAL()); 117 registerOperation(register_EXP()); 118 registerOperation(register_FILL()); 119 registerOperation(register_FLOOR()); 120 registerOperation(register_FULLY_CONNECTED()); 121 registerOperation(register_GATHER()); 122 registerOperation(register_GENERATE_PROPOSALS()); 123 registerOperation(register_GREATER()); 124 registerOperation(register_GREATER_EQUAL()); 125 registerOperation(register_HARD_SWISH()); 126 registerOperation(register_HEATMAP_MAX_KEYPOINT()); 127 registerOperation(register_INSTANCE_NORMALIZATION()); 128 registerOperation(register_L2_NORMALIZATION()); 129 registerOperation(register_L2_POOL_2D()); 130 registerOperation(register_LESS()); 131 registerOperation(register_LESS_EQUAL()); 132 registerOperation(register_LOCAL_RESPONSE_NORMALIZATION()); 133 registerOperation(register_LOG()); 134 registerOperation(register_LOGICAL_AND()); 135 registerOperation(register_LOGICAL_NOT()); 136 registerOperation(register_LOGICAL_OR()); 137 registerOperation(register_LOGISTIC()); 138 registerOperation(register_LOG_SOFTMAX()); 139 registerOperation(register_MAX_POOL_2D()); 140 registerOperation(register_MUL()); 141 registerOperation(register_NEG()); 142 registerOperation(register_NOT_EQUAL()); 143 registerOperation(register_PRELU()); 144 registerOperation(register_QUANTIZE()); 145 registerOperation(register_QUANTIZED_LSTM()); 146 registerOperation(register_RANK()); 147 registerOperation(register_REDUCE_ALL()); 148 registerOperation(register_REDUCE_ANY()); 149 registerOperation(register_REDUCE_MAX()); 150 registerOperation(register_REDUCE_MIN()); 151 registerOperation(register_REDUCE_PROD()); 152 registerOperation(register_REDUCE_SUM()); 153 registerOperation(register_RELU()); 154 registerOperation(register_RELU1()); 155 registerOperation(register_RELU6()); 156 registerOperation(register_RESIZE_BILINEAR()); 157 registerOperation(register_RESIZE_NEAREST_NEIGHBOR()); 158 registerOperation(register_ROI_ALIGN()); 159 registerOperation(register_ROI_POOLING()); 160 registerOperation(register_RSQRT()); 161 registerOperation(register_SELECT()); 162 registerOperation(register_SIN()); 163 registerOperation(register_SLICE()); 164 registerOperation(register_SOFTMAX()); 165 registerOperation(register_SQRT()); 166 registerOperation(register_SQUEEZE()); 167 registerOperation(register_STRIDED_SLICE()); 168 registerOperation(register_SUB()); 169 registerOperation(register_TANH()); 170 registerOperation(register_TOPK_V2()); 171 registerOperation(register_TRANSPOSE()); 172 registerOperation(register_TRANSPOSE_CONV_2D()); 173 registerOperation(register_UNIDIRECTIONAL_SEQUENCE_LSTM()); 174 registerOperation(register_UNIDIRECTIONAL_SEQUENCE_RNN()); 175 } 176 findOperation(OperationType operationType) const177const OperationRegistration* BuiltinOperationResolver::findOperation( 178 OperationType operationType) const { 179 auto index = static_cast<int32_t>(operationType); 180 if (index < 0 || index >= kNumberOfOperationTypes) { 181 return nullptr; 182 } 183 return mRegistrations[index]; 184 } 185 registerOperation(const OperationRegistration * operationRegistration)186void BuiltinOperationResolver::registerOperation( 187 const OperationRegistration* operationRegistration) { 188 CHECK(operationRegistration != nullptr); 189 auto index = static_cast<int32_t>(operationRegistration->type); 190 CHECK_LE(0, index); 191 CHECK_LT(index, kNumberOfOperationTypes); 192 CHECK(mRegistrations[index] == nullptr); 193 mRegistrations[index] = operationRegistration; 194 } 195 196 } // namespace nn 197 } // namespace android 198