• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/mutable_op_resolver.h"
17 
18 #include <string>
19 #include <unordered_map>
20 #include <utility>
21 
22 #include "tensorflow/lite/c/common.h"
23 #include "tensorflow/lite/core/api/op_resolver_internal.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 
26 namespace tflite {
27 
FindOp(tflite::BuiltinOperator op,int version) const28 const TfLiteRegistration* MutableOpResolver::FindOp(tflite::BuiltinOperator op,
29                                                     int version) const {
30   auto it = builtins_.find(std::make_pair(op, version));
31   if (it != builtins_.end()) {
32     return &it->second;
33   }
34   for (const OpResolver* other : other_op_resolvers_) {
35     const TfLiteRegistration* result = other->FindOp(op, version);
36     if (result != nullptr) {
37       return result;
38     }
39   }
40   return nullptr;
41 }
42 
FindOp(const char * op,int version) const43 const TfLiteRegistration* MutableOpResolver::FindOp(const char* op,
44                                                     int version) const {
45   auto it = custom_ops_.find(std::make_pair(op, version));
46   if (it != custom_ops_.end()) {
47     return &it->second;
48   }
49   for (const OpResolver* other : other_op_resolvers_) {
50     const TfLiteRegistration* result = other->FindOp(op, version);
51     if (result != nullptr) {
52       return result;
53     }
54   }
55   return nullptr;
56 }
57 
AddBuiltin(tflite::BuiltinOperator op,const TfLiteRegistration * registration,int version)58 void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
59                                    const TfLiteRegistration* registration,
60                                    int version) {
61   if (registration == nullptr) {
62     // Under certain conditions, builtin TfLiteRegistration factory methods may
63     // return null in the client library. This is generally benign, and we
64     // silently suppress resulting AddBuiltin calls here.
65     return;
66   }
67   TfLiteRegistration new_registration = *registration;
68   new_registration.custom_name = nullptr;
69   new_registration.builtin_code = op;
70   new_registration.version = version;
71   auto op_key = std::make_pair(op, version);
72   builtins_[op_key] = new_registration;
73   // The builtin op that is being added may be one that is not supported by
74   // tflite::ops::builtin::BuiltinOpResolver. Or the TfLiteRegistration for this
75   // builtin may be different than the one that BuiltinOpResolver would use,
76   // which could lead to different semantics. Both of those cases are considered
77   // "user defined ops".
78   may_directly_contain_user_defined_ops_ = true;
79 }
80 
AddBuiltin(tflite::BuiltinOperator op,const TfLiteRegistration * registration,int min_version,int max_version)81 void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
82                                    const TfLiteRegistration* registration,
83                                    int min_version, int max_version) {
84   for (int version = min_version; version <= max_version; ++version) {
85     AddBuiltin(op, registration, version);
86   }
87 }
88 
AddCustom(const char * name,const TfLiteRegistration * registration,int version)89 void MutableOpResolver::AddCustom(const char* name,
90                                   const TfLiteRegistration* registration,
91                                   int version) {
92   TfLiteRegistration new_registration = *registration;
93   new_registration.builtin_code = BuiltinOperator_CUSTOM;
94   new_registration.custom_name = name;
95   new_registration.version = version;
96   auto op_key = std::make_pair(name, version);
97   custom_ops_[op_key] = new_registration;
98   may_directly_contain_user_defined_ops_ = true;
99 }
100 
AddCustom(const char * name,const TfLiteRegistration * registration,int min_version,int max_version)101 void MutableOpResolver::AddCustom(const char* name,
102                                   const TfLiteRegistration* registration,
103                                   int min_version, int max_version) {
104   for (int version = min_version; version <= max_version; ++version) {
105     AddCustom(name, registration, version);
106   }
107 }
108 
AddAll(const MutableOpResolver & other)109 void MutableOpResolver::AddAll(const MutableOpResolver& other) {
110   // map::insert does not replace existing elements, and map::insert_or_assign
111   // wasn't added until C++17.
112   for (const auto& other_builtin : other.builtins_) {
113     builtins_[other_builtin.first] = other_builtin.second;
114   }
115   for (const auto& other_custom_op : other.custom_ops_) {
116     custom_ops_[other_custom_op.first] = other_custom_op.second;
117   }
118   other_op_resolvers_.insert(other_op_resolvers_.begin(),
119                              other.other_op_resolvers_.begin(),
120                              other.other_op_resolvers_.end());
121 }
122 
ChainOpResolver(const OpResolver * other)123 void MutableOpResolver::ChainOpResolver(const OpResolver* other) {
124   other_op_resolvers_.push_back(other);
125 }
126 
MayContainUserDefinedOps() const127 bool MutableOpResolver::MayContainUserDefinedOps() const {
128   if (may_directly_contain_user_defined_ops_) {
129     return true;
130   }
131   for (const OpResolver* other : other_op_resolvers_) {
132     if (OpResolverInternal::MayContainUserDefinedOps(*other)) {
133       return true;
134     }
135   }
136   return false;
137 }
138 
139 }  // namespace tflite
140