• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
16 #define TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
17 
#include <cstring>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/micro/compatibility.h"
#include "tensorflow/lite/schema/schema_generated.h"
22 
23 #ifndef TFLITE_REGISTRATIONS_MAX
24 #define TFLITE_REGISTRATIONS_MAX (128)
25 #endif
26 
27 namespace tflite {
28 
29 // Op versions discussed in this file are enumerated here:
30 // tensorflow/lite/tools/versioning/op_version.cc
31 
32 template <unsigned int tOpCount = TFLITE_REGISTRATIONS_MAX>
33 class MicroOpResolver : public OpResolver {
34  public:
FindOp(tflite::BuiltinOperator op,int version)35   const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
36                                    int version) const override {
37     for (unsigned int i = 0; i < registrations_len_; ++i) {
38       const TfLiteRegistration& registration = registrations_[i];
39       if ((registration.builtin_code == op) &&
40           (registration.version == version)) {
41         return &registration;
42       }
43     }
44     return nullptr;
45   }
46 
FindOp(const char * op,int version)47   const TfLiteRegistration* FindOp(const char* op, int version) const override {
48     for (unsigned int i = 0; i < registrations_len_; ++i) {
49       const TfLiteRegistration& registration = registrations_[i];
50       if ((registration.builtin_code == BuiltinOperator_CUSTOM) &&
51           (strcmp(registration.custom_name, op) == 0) &&
52           (registration.version == version)) {
53         return &registration;
54       }
55     }
56     return nullptr;
57   }
58 
59   void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
60                   int min_version = 1, int max_version = 1) {
61     for (int version = min_version; version <= max_version; ++version) {
62       if (registrations_len_ >= tOpCount) {
63         // TODO(b/147748244) - Add error reporting hooks so we can report this!
64         return;
65       }
66       TfLiteRegistration* new_registration =
67           &registrations_[registrations_len_];
68       registrations_len_ += 1;
69 
70       *new_registration = *registration;
71       new_registration->builtin_code = op;
72       new_registration->version = version;
73     }
74   }
75 
76   void AddCustom(const char* name, TfLiteRegistration* registration,
77                  int min_version = 1, int max_version = 1) {
78     for (int version = min_version; version <= max_version; ++version) {
79       if (registrations_len_ >= tOpCount) {
80         // TODO(b/147748244) - Add error reporting hooks so we can report this!
81         return;
82       }
83       TfLiteRegistration* new_registration =
84           &registrations_[registrations_len_];
85       registrations_len_ += 1;
86 
87       *new_registration = *registration;
88       new_registration->builtin_code = BuiltinOperator_CUSTOM;
89       new_registration->custom_name = name;
90       new_registration->version = version;
91     }
92   }
93 
GetRegistrationLength()94   unsigned int GetRegistrationLength() { return registrations_len_; }
95 
96  private:
97   TfLiteRegistration registrations_[tOpCount];
98   unsigned int registrations_len_ = 0;
99 
100   TF_LITE_REMOVE_VIRTUAL_DELETE
101 };
102 
103 // TODO(b/147854028): Consider switching all uses of MicroMutableOpResolver to
104 // MicroOpResolver.
105 class MicroMutableOpResolver
106     : public MicroOpResolver<TFLITE_REGISTRATIONS_MAX> {
107  private:
108   TF_LITE_REMOVE_VIRTUAL_DELETE
109 };
110 
111 };  // namespace tflite
112 
113 #endif  // TENSORFLOW_LITE_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
114