1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ 16 #define TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ 17 18 #include <stddef.h> 19 20 #include <string> 21 #include <unordered_map> 22 #include <utility> 23 24 #include "tensorflow/lite/c/common.h" 25 #include "tensorflow/lite/core/api/op_resolver.h" 26 #include "tensorflow/lite/schema/schema_generated.h" 27 #include "tensorflow/lite/util.h" 28 29 namespace tflite { 30 31 // Some versions of gcc don't support partial specialization in class scope, 32 // so these are defined in a namescope. 33 namespace op_resolver_hasher { 34 template <typename V> 35 struct ValueHasher { operatorValueHasher36 size_t operator()(const V& v) const { return std::hash<V>()(v); } 37 }; 38 39 template <> 40 struct ValueHasher<tflite::BuiltinOperator> { 41 size_t operator()(const tflite::BuiltinOperator& v) const { 42 return std::hash<int>()(static_cast<int>(v)); 43 } 44 }; 45 46 template <typename T> 47 struct OperatorKeyHasher { 48 size_t operator()(const T& x) const { 49 size_t a = ValueHasher<typename T::first_type>()(x.first); 50 size_t b = ValueHasher<typename T::second_type>()(x.second); 51 return CombineHashes({a, b}); 52 } 53 }; 54 } // namespace op_resolver_hasher 55 56 /// An OpResolver that is mutable, also used as the op in gen_op_registration. 57 /// A typical usage: 58 /// MutableOpResolver resolver; 59 /// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD()); 60 /// resolver.AddCustom("CustomOp", Register_CUSTOM_OP()); 61 /// InterpreterBuilder(model, resolver)(&interpreter); 62 class MutableOpResolver : public OpResolver { 63 public: 64 const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 65 int version) const override; 66 const TfLiteRegistration* FindOp(const char* op, int version) const override; 67 68 /// Registers the specified `version` of the specified builtin operator `op`. 69 /// Replaces any previous registration for the same operator version. 70 void AddBuiltin(tflite::BuiltinOperator op, 71 const TfLiteRegistration* registration, int version = 1); 72 73 /// Registers the specified version range (versions `min_version` to 74 /// `max_version`, inclusive) of the specified builtin operator `op`. 75 /// Replaces any previous registration for the same operator version. 76 void AddBuiltin(tflite::BuiltinOperator op, 77 const TfLiteRegistration* registration, int min_version, 78 int max_version); 79 80 /// Registers the specified `version` of the specified builtin operator `op`. 81 /// Replaces any previous registration for the same operator version. 82 void AddCustom(const char* name, const TfLiteRegistration* registration, 83 int version = 1); 84 85 /// Registers the specified version range (versions `min_version` to 86 /// `max_version`, inclusive) of the specified custom operator `name`. 87 /// Replaces any previous registration for the same operator version. 88 void AddCustom(const char* name, const TfLiteRegistration* registration, 89 int min_version, int max_version); 90 91 /// Registers all operator versions supported by another MutableOpResolver. 92 /// Replaces any previous registrations for the same operator versions, 93 /// except that registrations made with `AddBuiltin` or `AddCustom` always 94 /// take precedence over registrations made with `ChainOpResolver`. 95 void AddAll(const MutableOpResolver& other); 96 97 protected: 98 /// Registers all operator versions supported by another OpResolver, 99 /// except any already registered in this MutableOpResolver. 100 /// `other` must point to an OpResolver whose lifetime is at least as long 101 /// as the lifetime of the MutableOpResolver pointed to by `this`. 102 /// The OpResolver pointed to by `other` should not be modified during the 103 /// lifetime of this MutableOpResolver. 104 void ChainOpResolver(const OpResolver* other); 105 106 /// True if this OpResolver itself (as opposed to chained op resolvers 107 /// registed with ChainOpResolver) may contain user defined ops. 108 /// 109 /// By "user defined" ops, we mean any op definitions other than those 110 /// contained in tflite::ops::builtin::BuiltinOpResolver. 111 bool may_directly_contain_user_defined_ops_ = false; 112 113 private: 114 bool MayContainUserDefinedOps() const override; 115 116 typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey; 117 typedef std::pair<std::string, int> CustomOperatorKey; 118 119 std::unordered_map<BuiltinOperatorKey, TfLiteRegistration, 120 op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> > 121 builtins_; 122 std::unordered_map<CustomOperatorKey, TfLiteRegistration, 123 op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> > 124 custom_ops_; 125 std::vector<const OpResolver*> other_op_resolvers_; 126 }; 127 128 } // namespace tflite 129 130 #endif // TENSORFLOW_LITE_MUTABLE_OP_RESOLVER_H_ 131