1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/mutable_op_resolver.h"
17
18 #include <string>
19 #include <unordered_map>
20 #include <utility>
21
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/core/api/op_resolver_internal.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25
26 namespace tflite {
27
FindOp(tflite::BuiltinOperator op,int version) const28 const TfLiteRegistration* MutableOpResolver::FindOp(tflite::BuiltinOperator op,
29 int version) const {
30 auto it = builtins_.find(std::make_pair(op, version));
31 if (it != builtins_.end()) {
32 return &it->second;
33 }
34 for (const OpResolver* other : other_op_resolvers_) {
35 const TfLiteRegistration* result = other->FindOp(op, version);
36 if (result != nullptr) {
37 return result;
38 }
39 }
40 return nullptr;
41 }
42
FindOp(const char * op,int version) const43 const TfLiteRegistration* MutableOpResolver::FindOp(const char* op,
44 int version) const {
45 auto it = custom_ops_.find(std::make_pair(op, version));
46 if (it != custom_ops_.end()) {
47 return &it->second;
48 }
49 for (const OpResolver* other : other_op_resolvers_) {
50 const TfLiteRegistration* result = other->FindOp(op, version);
51 if (result != nullptr) {
52 return result;
53 }
54 }
55 return nullptr;
56 }
57
AddBuiltin(tflite::BuiltinOperator op,const TfLiteRegistration * registration,int version)58 void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
59 const TfLiteRegistration* registration,
60 int version) {
61 if (registration == nullptr) {
62 // Under certain conditions, builtin TfLiteRegistration factory methods may
63 // return null in the client library. This is generally benign, and we
64 // silently suppress resulting AddBuiltin calls here.
65 return;
66 }
67 TfLiteRegistration new_registration = *registration;
68 new_registration.custom_name = nullptr;
69 new_registration.builtin_code = op;
70 new_registration.version = version;
71 auto op_key = std::make_pair(op, version);
72 builtins_[op_key] = new_registration;
73 // The builtin op that is being added may be one that is not supported by
74 // tflite::ops::builtin::BuiltinOpResolver. Or the TfLiteRegistration for this
75 // builtin may be different than the one that BuiltinOpResolver would use,
76 // which could lead to different semantics. Both of those cases are considered
77 // "user defined ops".
78 may_directly_contain_user_defined_ops_ = true;
79 }
80
AddBuiltin(tflite::BuiltinOperator op,const TfLiteRegistration * registration,int min_version,int max_version)81 void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
82 const TfLiteRegistration* registration,
83 int min_version, int max_version) {
84 for (int version = min_version; version <= max_version; ++version) {
85 AddBuiltin(op, registration, version);
86 }
87 }
88
AddCustom(const char * name,const TfLiteRegistration * registration,int version)89 void MutableOpResolver::AddCustom(const char* name,
90 const TfLiteRegistration* registration,
91 int version) {
92 TfLiteRegistration new_registration = *registration;
93 new_registration.builtin_code = BuiltinOperator_CUSTOM;
94 new_registration.custom_name = name;
95 new_registration.version = version;
96 auto op_key = std::make_pair(name, version);
97 custom_ops_[op_key] = new_registration;
98 may_directly_contain_user_defined_ops_ = true;
99 }
100
AddCustom(const char * name,const TfLiteRegistration * registration,int min_version,int max_version)101 void MutableOpResolver::AddCustom(const char* name,
102 const TfLiteRegistration* registration,
103 int min_version, int max_version) {
104 for (int version = min_version; version <= max_version; ++version) {
105 AddCustom(name, registration, version);
106 }
107 }
108
AddAll(const MutableOpResolver & other)109 void MutableOpResolver::AddAll(const MutableOpResolver& other) {
110 // map::insert does not replace existing elements, and map::insert_or_assign
111 // wasn't added until C++17.
112 for (const auto& other_builtin : other.builtins_) {
113 builtins_[other_builtin.first] = other_builtin.second;
114 }
115 for (const auto& other_custom_op : other.custom_ops_) {
116 custom_ops_[other_custom_op.first] = other_custom_op.second;
117 }
118 other_op_resolvers_.insert(other_op_resolvers_.begin(),
119 other.other_op_resolvers_.begin(),
120 other.other_op_resolvers_.end());
121 }
122
ChainOpResolver(const OpResolver * other)123 void MutableOpResolver::ChainOpResolver(const OpResolver* other) {
124 other_op_resolvers_.push_back(other);
125 }
126
MayContainUserDefinedOps() const127 bool MutableOpResolver::MayContainUserDefinedOps() const {
128 if (may_directly_contain_user_defined_ops_) {
129 return true;
130 }
131 for (const OpResolver* other : other_op_resolvers_) {
132 if (OpResolverInternal::MayContainUserDefinedOps(*other)) {
133 return true;
134 }
135 }
136 return false;
137 }
138
139 } // namespace tflite
140