1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ 16 #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ 17 18 #include <cstdio> 19 #include <cstring> 20 21 #include "tensorflow/lite/c/common.h" 22 #include "tensorflow/lite/core/api/error_reporter.h" 23 #include "tensorflow/lite/core/api/flatbuffer_conversions.h" 24 #include "tensorflow/lite/kernels/internal/compatibility.h" 25 #include "tensorflow/lite/kernels/op_macros.h" 26 #include "tensorflow/lite/micro/compatibility.h" 27 #include "tensorflow/lite/micro/kernels/ethosu.h" 28 #include "tensorflow/lite/micro/kernels/fully_connected.h" 29 #include "tensorflow/lite/micro/kernels/micro_ops.h" 30 #include "tensorflow/lite/micro/micro_op_resolver.h" 31 #include "tensorflow/lite/schema/schema_generated.h" 32 33 namespace tflite { 34 TfLiteRegistration* Register_DETECTION_POSTPROCESS(); 35 36 template <unsigned int tOpCount> 37 class MicroMutableOpResolver : public MicroOpResolver { 38 public: 39 TF_LITE_REMOVE_VIRTUAL_DELETE 40 41 explicit MicroMutableOpResolver(ErrorReporter* error_reporter = nullptr) error_reporter_(error_reporter)42 : error_reporter_(error_reporter) {} 43 FindOp(tflite::BuiltinOperator op)44 const TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { 45 if (op == BuiltinOperator_CUSTOM) return nullptr; 46 47 for (unsigned int i = 0; i < registrations_len_; ++i) { 48 const TfLiteRegistration& registration = registrations_[i]; 49 if (registration.builtin_code == op) { 50 return ®istration; 51 } 52 } 53 return nullptr; 54 } 55 FindOp(const char * op)56 const TfLiteRegistration* FindOp(const char* op) const override { 57 for (unsigned int i = 0; i < registrations_len_; ++i) { 58 const TfLiteRegistration& registration = registrations_[i]; 59 if ((registration.builtin_code == BuiltinOperator_CUSTOM) && 60 (strcmp(registration.custom_name, op) == 0)) { 61 return ®istration; 62 } 63 } 64 return nullptr; 65 } 66 GetOpDataParser(BuiltinOperator op)67 MicroOpResolver::BuiltinParseFunction GetOpDataParser( 68 BuiltinOperator op) const override { 69 TFLITE_DCHECK(num_buitin_ops_ <= tOpCount); 70 for (unsigned int i = 0; i < num_buitin_ops_; ++i) { 71 if (builtin_codes_[i] == op) return builtin_parsers_[i]; 72 } 73 return nullptr; 74 } 75 76 // Registers a Custom Operator with the MicroOpResolver. 77 // 78 // Only the first call for a given name will be successful. i.e. if this 79 // function is called again for a previously added Custom Operator, the 80 // MicroOpResolver will be unchanged and this function will return 81 // kTfLiteError. AddCustom(const char * name,TfLiteRegistration * registration)82 TfLiteStatus AddCustom(const char* name, TfLiteRegistration* registration) { 83 if (registrations_len_ >= tOpCount) { 84 if (error_reporter_) { 85 TF_LITE_REPORT_ERROR( 86 error_reporter_, 87 "Couldn't register custom op '%s', resolver size is too small (%d)", 88 name, tOpCount); 89 } 90 return kTfLiteError; 91 } 92 93 if (FindOp(name) != nullptr) { 94 if (error_reporter_ != nullptr) { 95 TF_LITE_REPORT_ERROR(error_reporter_, 96 "Calling AddCustom for the same op more than once " 97 "is not supported (Op: %s).", 98 name); 99 } 100 return kTfLiteError; 101 } 102 103 TfLiteRegistration* new_registration = ®istrations_[registrations_len_]; 104 registrations_len_ += 1; 105 106 *new_registration = *registration; 107 new_registration->builtin_code = BuiltinOperator_CUSTOM; 108 new_registration->custom_name = name; 109 return kTfLiteOk; 110 } 111 112 // The Add* functions below add the various Builtin operators to the 113 // MicroMutableOpResolver object. 114 AddAbs()115 TfLiteStatus AddAbs() { 116 return AddBuiltin(BuiltinOperator_ABS, tflite::ops::micro::Register_ABS(), 117 ParseAbs); 118 } 119 AddAdd()120 TfLiteStatus AddAdd() { 121 return AddBuiltin(BuiltinOperator_ADD, tflite::ops::micro::Register_ADD(), 122 ParseAdd); 123 } 124 AddArgMax()125 TfLiteStatus AddArgMax() { 126 return AddBuiltin(BuiltinOperator_ARG_MAX, 127 tflite::ops::micro::Register_ARG_MAX(), ParseArgMax); 128 } 129 AddArgMin()130 TfLiteStatus AddArgMin() { 131 return AddBuiltin(BuiltinOperator_ARG_MIN, 132 tflite::ops::micro::Register_ARG_MIN(), ParseArgMin); 133 } 134 AddAveragePool2D()135 TfLiteStatus AddAveragePool2D() { 136 return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, 137 tflite::ops::micro::Register_AVERAGE_POOL_2D(), 138 ParsePool); 139 } 140 AddBatchToSpaceND()141 TfLiteStatus AddBatchToSpaceND() { 142 return AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, 143 Register_BATCH_TO_SPACE_ND(), ParseBatchToSpaceNd); 144 } 145 AddCast()146 TfLiteStatus AddCast() { 147 return AddBuiltin(BuiltinOperator_CAST, Register_CAST(), ParseCast); 148 } 149 AddCeil()150 TfLiteStatus AddCeil() { 151 return AddBuiltin(BuiltinOperator_CEIL, tflite::ops::micro::Register_CEIL(), 152 ParseCeil); 153 } 154 AddCircularBuffer()155 TfLiteStatus AddCircularBuffer() { 156 return AddCustom("CIRCULAR_BUFFER", 157 tflite::ops::micro::Register_CIRCULAR_BUFFER()); 158 } 159 AddConcatenation()160 TfLiteStatus AddConcatenation() { 161 return AddBuiltin(BuiltinOperator_CONCATENATION, 162 tflite::ops::micro::Register_CONCATENATION(), 163 ParseConcatenation); 164 } 165 AddConv2D()166 TfLiteStatus AddConv2D() { 167 return AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), ParseConv2D); 168 } 169 AddCos()170 TfLiteStatus AddCos() { 171 return AddBuiltin(BuiltinOperator_COS, tflite::ops::micro::Register_COS(), 172 ParseCos); 173 } 174 AddDepthwiseConv2D()175 TfLiteStatus AddDepthwiseConv2D() { 176 return AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, 177 Register_DEPTHWISE_CONV_2D(), ParseDepthwiseConv2D); 178 } 179 AddDequantize()180 TfLiteStatus AddDequantize() { 181 return AddBuiltin(BuiltinOperator_DEQUANTIZE, 182 tflite::ops::micro::Register_DEQUANTIZE(), 183 ParseDequantize); 184 } 185 AddDetectionPostprocess()186 TfLiteStatus AddDetectionPostprocess() { 187 return AddCustom("TFLite_Detection_PostProcess", 188 tflite::Register_DETECTION_POSTPROCESS()); 189 } 190 AddEqual()191 TfLiteStatus AddEqual() { 192 return AddBuiltin(BuiltinOperator_EQUAL, 193 tflite::ops::micro::Register_EQUAL(), ParseEqual); 194 } 195 AddEthosU()196 TfLiteStatus AddEthosU() { 197 TfLiteRegistration* registration = tflite::Register_ETHOSU(); 198 if (registration) { 199 return AddCustom(tflite::GetString_ETHOSU(), registration); 200 } 201 return kTfLiteOk; 202 } 203 AddExp()204 TfLiteStatus AddExp() { 205 return AddBuiltin(BuiltinOperator_EXP, Register_EXP(), ParseExp); 206 } 207 AddFloor()208 TfLiteStatus AddFloor() { 209 return AddBuiltin(BuiltinOperator_FLOOR, 210 tflite::ops::micro::Register_FLOOR(), ParseFloor); 211 } 212 213 TfLiteStatus AddFullyConnected( 214 const TfLiteRegistration& registration = Register_FULLY_CONNECTED()) { 215 return AddBuiltin(BuiltinOperator_FULLY_CONNECTED, registration, 216 ParseFullyConnected); 217 } 218 AddGreater()219 TfLiteStatus AddGreater() { 220 return AddBuiltin(BuiltinOperator_GREATER, 221 tflite::ops::micro::Register_GREATER(), ParseGreater); 222 } 223 AddGreaterEqual()224 TfLiteStatus AddGreaterEqual() { 225 return AddBuiltin(BuiltinOperator_GREATER_EQUAL, 226 tflite::ops::micro::Register_GREATER_EQUAL(), 227 ParseGreaterEqual); 228 } 229 AddHardSwish()230 TfLiteStatus AddHardSwish() { 231 return AddBuiltin(BuiltinOperator_HARD_SWISH, 232 tflite::ops::micro::Register_HARD_SWISH(), 233 ParseHardSwish); 234 } 235 AddL2Normalization()236 TfLiteStatus AddL2Normalization() { 237 return AddBuiltin(BuiltinOperator_L2_NORMALIZATION, 238 tflite::ops::micro::Register_L2_NORMALIZATION(), 239 ParseL2Normalization); 240 } 241 AddLess()242 TfLiteStatus AddLess() { 243 return AddBuiltin(BuiltinOperator_LESS, tflite::ops::micro::Register_LESS(), 244 ParseLess); 245 } 246 AddLessEqual()247 TfLiteStatus AddLessEqual() { 248 return AddBuiltin(BuiltinOperator_LESS_EQUAL, 249 tflite::ops::micro::Register_LESS_EQUAL(), 250 ParseLessEqual); 251 } 252 AddLog()253 TfLiteStatus AddLog() { 254 return AddBuiltin(BuiltinOperator_LOG, tflite::ops::micro::Register_LOG(), 255 ParseLog); 256 } 257 AddLogicalAnd()258 TfLiteStatus AddLogicalAnd() { 259 return AddBuiltin(BuiltinOperator_LOGICAL_AND, 260 tflite::ops::micro::Register_LOGICAL_AND(), 261 ParseLogicalAnd); 262 } 263 AddLogicalNot()264 TfLiteStatus AddLogicalNot() { 265 return AddBuiltin(BuiltinOperator_LOGICAL_NOT, 266 tflite::ops::micro::Register_LOGICAL_NOT(), 267 ParseLogicalNot); 268 } 269 AddLogicalOr()270 TfLiteStatus AddLogicalOr() { 271 return AddBuiltin(BuiltinOperator_LOGICAL_OR, 272 tflite::ops::micro::Register_LOGICAL_OR(), 273 ParseLogicalOr); 274 } 275 AddLogistic()276 TfLiteStatus AddLogistic() { 277 return AddBuiltin(BuiltinOperator_LOGISTIC, 278 tflite::ops::micro::Register_LOGISTIC(), ParseLogistic); 279 } 280 AddMaximum()281 TfLiteStatus AddMaximum() { 282 return AddBuiltin(BuiltinOperator_MAXIMUM, 283 tflite::ops::micro::Register_MAXIMUM(), ParseMaximum); 284 } 285 AddMaxPool2D()286 TfLiteStatus AddMaxPool2D() { 287 return AddBuiltin(BuiltinOperator_MAX_POOL_2D, 288 tflite::ops::micro::Register_MAX_POOL_2D(), ParsePool); 289 } 290 AddMean()291 TfLiteStatus AddMean() { 292 return AddBuiltin(BuiltinOperator_MEAN, tflite::ops::micro::Register_MEAN(), 293 ParseReducer); 294 } 295 AddMinimum()296 TfLiteStatus AddMinimum() { 297 return AddBuiltin(BuiltinOperator_MINIMUM, 298 tflite::ops::micro::Register_MINIMUM(), ParseMinimum); 299 } 300 AddMul()301 TfLiteStatus AddMul() { 302 return AddBuiltin(BuiltinOperator_MUL, tflite::ops::micro::Register_MUL(), 303 ParseMul); 304 } 305 AddNeg()306 TfLiteStatus AddNeg() { 307 return AddBuiltin(BuiltinOperator_NEG, tflite::ops::micro::Register_NEG(), 308 ParseNeg); 309 } 310 AddNotEqual()311 TfLiteStatus AddNotEqual() { 312 return AddBuiltin(BuiltinOperator_NOT_EQUAL, 313 tflite::ops::micro::Register_NOT_EQUAL(), ParseNotEqual); 314 } 315 AddPack()316 TfLiteStatus AddPack() { 317 return AddBuiltin(BuiltinOperator_PACK, tflite::ops::micro::Register_PACK(), 318 ParsePack); 319 } 320 AddPad()321 TfLiteStatus AddPad() { 322 return AddBuiltin(BuiltinOperator_PAD, tflite::ops::micro::Register_PAD(), 323 ParsePad); 324 } 325 AddPadV2()326 TfLiteStatus AddPadV2() { 327 return AddBuiltin(BuiltinOperator_PADV2, 328 tflite::ops::micro::Register_PADV2(), ParsePadV2); 329 } 330 AddPrelu()331 TfLiteStatus AddPrelu() { 332 return AddBuiltin(BuiltinOperator_PRELU, 333 tflite::ops::micro::Register_PRELU(), ParsePrelu); 334 } 335 AddQuantize()336 TfLiteStatus AddQuantize() { 337 return AddBuiltin(BuiltinOperator_QUANTIZE, Register_QUANTIZE(), 338 ParseQuantize); 339 } 340 AddReduceMax()341 TfLiteStatus AddReduceMax() { 342 return AddBuiltin(BuiltinOperator_REDUCE_MAX, 343 tflite::ops::micro::Register_REDUCE_MAX(), ParseReducer); 344 } 345 AddRelu()346 TfLiteStatus AddRelu() { 347 return AddBuiltin(BuiltinOperator_RELU, tflite::ops::micro::Register_RELU(), 348 ParseRelu); 349 } 350 AddRelu6()351 TfLiteStatus AddRelu6() { 352 return AddBuiltin(BuiltinOperator_RELU6, 353 tflite::ops::micro::Register_RELU6(), ParseRelu6); 354 } 355 AddReshape()356 TfLiteStatus AddReshape() { 357 return AddBuiltin(BuiltinOperator_RESHAPE, 358 tflite::ops::micro::Register_RESHAPE(), ParseReshape); 359 } 360 AddResizeNearestNeighbor()361 TfLiteStatus AddResizeNearestNeighbor() { 362 return AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 363 tflite::ops::micro::Register_RESIZE_NEAREST_NEIGHBOR(), 364 ParseResizeNearestNeighbor); 365 } 366 AddRound()367 TfLiteStatus AddRound() { 368 return AddBuiltin(BuiltinOperator_ROUND, 369 tflite::ops::micro::Register_ROUND(), ParseRound); 370 } 371 AddRsqrt()372 TfLiteStatus AddRsqrt() { 373 return AddBuiltin(BuiltinOperator_RSQRT, 374 tflite::ops::micro::Register_RSQRT(), ParseRsqrt); 375 } 376 AddShape()377 TfLiteStatus AddShape() { 378 return AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE(), ParseShape); 379 } 380 AddSin()381 TfLiteStatus AddSin() { 382 return AddBuiltin(BuiltinOperator_SIN, tflite::ops::micro::Register_SIN(), 383 ParseSin); 384 } 385 AddSoftmax()386 TfLiteStatus AddSoftmax() { 387 return AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(), 388 ParseSoftmax); 389 } 390 AddSpaceToBatchNd()391 TfLiteStatus AddSpaceToBatchNd() { 392 return AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, 393 Register_SPACE_TO_BATCH_ND(), ParseSpaceToBatchNd); 394 } 395 AddSplit()396 TfLiteStatus AddSplit() { 397 return AddBuiltin(BuiltinOperator_SPLIT, 398 tflite::ops::micro::Register_SPLIT(), ParseSplit); 399 } 400 AddSplitV()401 TfLiteStatus AddSplitV() { 402 return AddBuiltin(BuiltinOperator_SPLIT_V, 403 tflite::ops::micro::Register_SPLIT_V(), ParseSplitV); 404 } 405 AddSqrt()406 TfLiteStatus AddSqrt() { 407 return AddBuiltin(BuiltinOperator_SQRT, tflite::ops::micro::Register_SQRT(), 408 ParseSqrt); 409 } 410 AddSquare()411 TfLiteStatus AddSquare() { 412 return AddBuiltin(BuiltinOperator_SQUARE, 413 tflite::ops::micro::Register_SQUARE(), ParseSquare); 414 } 415 AddStridedSlice()416 TfLiteStatus AddStridedSlice() { 417 return AddBuiltin(BuiltinOperator_STRIDED_SLICE, 418 tflite::ops::micro::Register_STRIDED_SLICE(), 419 ParseStridedSlice); 420 } 421 AddSub()422 TfLiteStatus AddSub() { 423 return AddBuiltin(BuiltinOperator_SUB, tflite::ops::micro::Register_SUB(), 424 ParseSub); 425 } 426 AddSvdf()427 TfLiteStatus AddSvdf() { 428 return AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(), ParseSvdf); 429 } 430 AddTanh()431 TfLiteStatus AddTanh() { 432 return AddBuiltin(BuiltinOperator_TANH, tflite::ops::micro::Register_TANH(), 433 ParseTanh); 434 } 435 AddUnpack()436 TfLiteStatus AddUnpack() { 437 return AddBuiltin(BuiltinOperator_UNPACK, 438 tflite::ops::micro::Register_UNPACK(), ParseUnpack); 439 } 440 AddZerosLike()441 TfLiteStatus AddZerosLike() { 442 return AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE(), 443 ParseZerosLike); 444 } 445 GetRegistrationLength()446 unsigned int GetRegistrationLength() { return registrations_len_; } 447 448 private: AddBuiltin(tflite::BuiltinOperator op,const TfLiteRegistration & registration,MicroOpResolver::BuiltinParseFunction parser)449 TfLiteStatus AddBuiltin(tflite::BuiltinOperator op, 450 const TfLiteRegistration& registration, 451 MicroOpResolver::BuiltinParseFunction parser) { 452 if (op == BuiltinOperator_CUSTOM) { 453 if (error_reporter_ != nullptr) { 454 TF_LITE_REPORT_ERROR(error_reporter_, 455 "Invalid parameter BuiltinOperator_CUSTOM to the " 456 "AddBuiltin function."); 457 } 458 return kTfLiteError; 459 } 460 461 if (FindOp(op) != nullptr) { 462 if (error_reporter_ != nullptr) { 463 TF_LITE_REPORT_ERROR(error_reporter_, 464 "Calling AddBuiltin with the same op more than " 465 "once is not supported (Op: #%d).", 466 op); 467 } 468 return kTfLiteError; 469 } 470 471 if (registrations_len_ >= tOpCount) { 472 if (error_reporter_) { 473 TF_LITE_REPORT_ERROR(error_reporter_, 474 "Couldn't register builtin op #%d, resolver size " 475 "is too small (%d).", 476 op, tOpCount); 477 } 478 return kTfLiteError; 479 } 480 481 registrations_[registrations_len_] = registration; 482 // Strictly speaking, the builtin_code is not necessary for TFLM but filling 483 // it in regardless. 484 registrations_[registrations_len_].builtin_code = op; 485 registrations_len_++; 486 487 builtin_codes_[num_buitin_ops_] = op; 488 builtin_parsers_[num_buitin_ops_] = parser; 489 num_buitin_ops_++; 490 491 return kTfLiteOk; 492 } 493 494 TfLiteRegistration registrations_[tOpCount]; 495 unsigned int registrations_len_ = 0; 496 497 // Arrays (and counter) to store the builtin codes and their corresponding 498 // parse functions as these are registered with the Op Resolver. 499 BuiltinOperator builtin_codes_[tOpCount]; 500 MicroOpResolver::BuiltinParseFunction builtin_parsers_[tOpCount]; 501 unsigned int num_buitin_ops_ = 0; 502 503 ErrorReporter* error_reporter_; 504 }; 505 506 }; // namespace tflite 507 508 #endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ 509