1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ 16 #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ 17 18 #include "tensorflow/lite/c/common.h" 19 #include "tensorflow/lite/core/api/op_resolver.h" 20 #include "tensorflow/lite/micro/compatibility.h" 21 #include "tensorflow/lite/schema/schema_generated.h" 22 23 #ifndef TFLITE_REGISTRATIONS_MAX 24 #define TFLITE_REGISTRATIONS_MAX (128) 25 #endif 26 27 namespace tflite { 28 29 // Op versions discussed in this file are enumerated here: 30 // tensorflow/lite/tools/versioning/op_version.cc 31 32 template <unsigned int tOpCount = TFLITE_REGISTRATIONS_MAX> 33 class MicroOpResolver : public OpResolver { 34 public: FindOp(tflite::BuiltinOperator op,int version)35 const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 36 int version) const override { 37 for (unsigned int i = 0; i < registrations_len_; ++i) { 38 const TfLiteRegistration& registration = registrations_[i]; 39 if ((registration.builtin_code == op) && 40 (registration.version == version)) { 41 return ®istration; 42 } 43 } 44 return nullptr; 45 } 46 FindOp(const char * op,int version)47 const TfLiteRegistration* FindOp(const char* op, int version) const override { 48 for (unsigned int i = 0; i < registrations_len_; ++i) { 49 const TfLiteRegistration& registration = registrations_[i]; 50 if ((registration.builtin_code == BuiltinOperator_CUSTOM) && 51 (strcmp(registration.custom_name, op) == 0) && 52 (registration.version == version)) { 53 return ®istration; 54 } 55 } 56 return nullptr; 57 } 58 59 void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, 60 int min_version = 1, int max_version = 1) { 61 for (int version = min_version; version <= max_version; ++version) { 62 if (registrations_len_ >= tOpCount) { 63 // TODO(b/147748244) - Add error reporting hooks so we can report this! 64 return; 65 } 66 TfLiteRegistration* new_registration = 67 ®istrations_[registrations_len_]; 68 registrations_len_ += 1; 69 70 *new_registration = *registration; 71 new_registration->builtin_code = op; 72 new_registration->version = version; 73 } 74 } 75 76 void AddCustom(const char* name, TfLiteRegistration* registration, 77 int min_version = 1, int max_version = 1) { 78 for (int version = min_version; version <= max_version; ++version) { 79 if (registrations_len_ >= tOpCount) { 80 // TODO(b/147748244) - Add error reporting hooks so we can report this! 81 return; 82 } 83 TfLiteRegistration* new_registration = 84 ®istrations_[registrations_len_]; 85 registrations_len_ += 1; 86 87 *new_registration = *registration; 88 new_registration->builtin_code = BuiltinOperator_CUSTOM; 89 new_registration->custom_name = name; 90 new_registration->version = version; 91 } 92 } 93 GetRegistrationLength()94 unsigned int GetRegistrationLength() { return registrations_len_; } 95 96 private: 97 TfLiteRegistration registrations_[tOpCount]; 98 unsigned int registrations_len_ = 0; 99 100 TF_LITE_REMOVE_VIRTUAL_DELETE 101 }; 102 103 // TODO(b/147854028): Consider switching all uses of MicroMutableOpResolver to 104 // MicroOpResolver. 105 class MicroMutableOpResolver 106 : public MicroOpResolver<TFLITE_REGISTRATIONS_MAX> { 107 private: 108 TF_LITE_REMOVE_VIRTUAL_DELETE 109 }; 110 111 }; // namespace tflite 112 113 #endif // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_ 114