/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Mechanism to instantiate classes by name.
//
// This mechanism is useful if the concrete classes to be instantiated are not
// statically known (e.g., if their names are read from a dynamically-provided
// config).
//
// In that case, the first step is to define the API implemented by the
// instantiated classes. E.g.,
//
//   // In a header file function.h:
//
//   // Abstract function that takes a double and returns a double.
//   class Function : public RegisterableClass<Function> {
//    public:
//     virtual ~Function() {}
//     virtual double Evaluate(double x) = 0;
//   };
//
//   // Should be inside namespace libtextclassifier::nlp_core.
//   TC_DECLARE_CLASS_REGISTRY_NAME(Function);
//
// Notice the inheritance from RegisterableClass<Function>. RegisterableClass
// is defined by this file (registry.h). Under the hood, this inheritance
// defines a "registry" that maps names (zero-terminated arrays of chars) to
// factory methods that create Functions. You should give a human-readable name
// to this registry. To do that, use the following macro in a .cc file (it has
// to be a .cc file, as it defines some static data):
//
//   // Inside function.cc
//   // Should be inside namespace libtextclassifier::nlp_core.
//   TC_DEFINE_CLASS_REGISTRY_NAME("function", Function);
//
// Now, let's define a few concrete Functions: e.g.,
//
//   class Cos : public Function {
//    public:
//     double Evaluate(double x) override { return cos(x); }
//     TC_DEFINE_REGISTRATION_METHOD("cos", Cos);
//   };
//
//   class Exp : public Function {
//    public:
//     double Evaluate(double x) override { return exp(x); }
//     TC_DEFINE_REGISTRATION_METHOD("exp", Exp);
//   };
//
// Each concrete Function implementation should have (in the public section) the
// macro
//
//   TC_DEFINE_REGISTRATION_METHOD("name", implementation_class);
//
// This defines a RegisterClass static method that, when invoked, associates
// "name" with a factory method that creates instances of implementation_class.
//
// Before instantiating Functions by name, we need to tell our system which
// Functions we may be interested in. This is done by calling
// Foo::RegisterClass() for each relevant Foo implementation of Function. It is
// ok to call the same Foo::RegisterClass() multiple times (even in parallel):
// only the first call will perform something, the others will return
// immediately. Calls for *different* classes, however, all add to one shared
// registry list that is not protected by a lock, so they should not run
// concurrently.
//
//   Cos::RegisterClass();
//   Exp::RegisterClass();
//
// Now, let's instantiate a Function based on its name. This gets a lot more
// interesting if the Function name is not statically known (e.g., read from an
// input proto):
//
//   std::unique_ptr<Function> f(Function::Create("cos"));
//   double result = f->Evaluate(arg);
//
// NOTE: the same binary can use this mechanism for different APIs.
E.g., one 88 // can also have (in the binary with Function, Sin, Cos, etc): 89 // 90 // class IntFunction : public RegisterableClass<IntFunction> { 91 // public: 92 // virtual ~IntFunction() {} 93 // virtual int Evaluate(int k) = 0; 94 // }; 95 // 96 // TC_DECLARE_CLASS_REGISTRY_NAME(IntFunction); 97 // 98 // TC_DEFINE_CLASS_REGISTRY_NAME("int function", IntFunction); 99 // 100 // class Inc : public IntFunction { 101 // public: 102 // int Evaluate(int k) override { return k + 1; } 103 // TC_DEFINE_REGISTRATION_METHOD("inc", Inc); 104 // }; 105 // 106 // RegisterableClass<Function> and RegisterableClass<IntFunction> define their 107 // own registries: each maps string names to implementation of the corresponding 108 // API. 109 110 #ifndef LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_ 111 #define LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_ 112 113 #include <stdlib.h> 114 #include <string.h> 115 116 #include <string> 117 118 #include "util/base/logging.h" 119 120 namespace libtextclassifier { 121 namespace nlp_core { 122 123 namespace internal { 124 // Registry that associates keys (zero-terminated array of chars) with values. 125 // Values are pointers to type T (the template parameter). This is used to 126 // store the association between component names and factory methods that 127 // produce those components; the error messages are focused on that case. 128 // 129 // Internally, this registry uses a linked list of (key, value) pairs. We do 130 // not use an STL map, list, etc because we aim for small code size. 131 template <class T> 132 class ComponentRegistry { 133 public: ComponentRegistry(const char * name)134 explicit ComponentRegistry(const char *name) : name_(name), head_(nullptr) {} 135 136 // Adds a the (key, value) pair to this registry (if the key does not already 137 // exists in this registry) and returns true. If the registry already has a 138 // mapping for key, returns false and does not modify the registry. 
NOTE: the 139 // error (false) case happens even if the existing value for key is equal with 140 // the new one. 141 // 142 // This method does not take ownership of key, nor of value. Add(const char * key,T * value)143 bool Add(const char *key, T *value) { 144 const Cell *old_cell = FindCell(key); 145 if (old_cell != nullptr) { 146 TC_LOG(ERROR) << "Duplicate component: " << key; 147 return false; 148 } 149 Cell *new_cell = new Cell(key, value, head_); 150 head_ = new_cell; 151 return true; 152 } 153 154 // Returns the value attached to a key in this registry. Returns nullptr on 155 // error (e.g., unknown key). Lookup(const char * key)156 T *Lookup(const char *key) const { 157 const Cell *cell = FindCell(key); 158 if (cell == nullptr) { 159 TC_LOG(ERROR) << "Unknown " << name() << " component: " << key; 160 } 161 return (cell == nullptr) ? nullptr : cell->value(); 162 } 163 Lookup(const std::string & key)164 T *Lookup(const std::string &key) const { return Lookup(key.c_str()); } 165 166 // Returns name of this ComponentRegistry. name()167 const char *name() const { return name_; } 168 169 private: 170 // Cell for the singly-linked list underlying this ComponentRegistry. Each 171 // cell contains a key, the value for that key, as well as a pointer to the 172 // next Cell from the list. 173 class Cell { 174 public: 175 // Constructs a new Cell. Cell(const char * key,T * value,Cell * next)176 Cell(const char *key, T *value, Cell *next) 177 : key_(key), value_(value), next_(next) {} 178 key()179 const char *key() const { return key_; } value()180 T *value() const { return value_; } next()181 Cell *next() const { return next_; } 182 183 private: 184 const char *const key_; 185 T *const value_; 186 Cell *const next_; 187 }; 188 189 // Finds Cell for indicated key in the singly-linked list pointed to by head_. 190 // Returns pointer to that first Cell with that key, or nullptr if no such 191 // Cell (i.e., unknown key). 
192 // 193 // Caller does NOT own the returned pointer. FindCell(const char * key)194 const Cell *FindCell(const char *key) const { 195 Cell *c = head_; 196 while (c != nullptr && strcmp(key, c->key()) != 0) { 197 c = c->next(); 198 } 199 return c; 200 } 201 202 // Human-readable description for this ComponentRegistry. For debug purposes. 203 const char *const name_; 204 205 // Pointer to the first Cell from the underlying list of (key, value) pairs. 206 Cell *head_; 207 }; 208 } // namespace internal 209 210 // Base class for registerable classes. 211 template <class T> 212 class RegisterableClass { 213 public: 214 // Factory function type. 215 typedef T *(Factory)(); 216 217 // Registry type. 218 typedef internal::ComponentRegistry<Factory> Registry; 219 220 // Creates a new instance of T. Returns pointer to new instance or nullptr in 221 // case of errors (e.g., unknown component). 222 // 223 // Passes ownership of the returned pointer to the caller. Create(const std::string & name)224 static T *Create(const std::string &name) { // NOLINT 225 auto *factory = registry()->Lookup(name); 226 if (factory == nullptr) { 227 TC_LOG(ERROR) << "Unknown RegisterableClass " << name; 228 return nullptr; 229 } 230 return factory(); 231 } 232 233 // Returns registry for class. registry()234 static Registry *registry() { 235 static Registry *registry_for_type_t = new Registry(kRegistryName); 236 return registry_for_type_t; 237 } 238 239 protected: 240 // Factory method for subclass ComponentClass. Used internally by the static 241 // method RegisterClass() defined by TC_DEFINE_REGISTRATION_METHOD. 242 template <class ComponentClass> _internal_component_factory()243 static T *_internal_component_factory() { 244 return new ComponentClass(); 245 } 246 247 private: 248 // Human-readable name for the registry for this class. 
249 static const char kRegistryName[]; 250 }; 251 252 // Defines the static method component_class::RegisterClass() that should be 253 // called before trying to instantiate component_class by name. Should be used 254 // inside the public section of the declaration of component_class. See 255 // comments at the top-level of this file. 256 #define TC_DEFINE_REGISTRATION_METHOD(component_name, component_class) \ 257 static void RegisterClass() { \ 258 static bool once = registry()->Add( \ 259 component_name, &_internal_component_factory<component_class>); \ 260 if (!once) { \ 261 TC_LOG(ERROR) << "Problem registering " << component_name; \ 262 } \ 263 TC_DCHECK(once); \ 264 } 265 266 // Defines the human-readable name of the registry associated with base_class. 267 #define TC_DECLARE_CLASS_REGISTRY_NAME(base_class) \ 268 template <> \ 269 const char ::libtextclassifier::nlp_core::RegisterableClass< \ 270 base_class>::kRegistryName[] 271 272 // Defines the human-readable name of the registry associated with base_class. 273 #define TC_DEFINE_CLASS_REGISTRY_NAME(registry_name, base_class) \ 274 template <> \ 275 const char ::libtextclassifier::nlp_core::RegisterableClass< \ 276 base_class>::kRegistryName[] = registry_name 277 278 } // namespace nlp_core 279 } // namespace libtextclassifier 280 281 #endif // LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_ 282