/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Generic feature extractor for extracting features from objects. The feature
// extractor can be used for extracting features from any object. The feature
// extractor and feature function classes are template classes that have to
// be instantiated for extracting features from a specific object type.
//
// A feature extractor consists of a hierarchy of feature functions. Each
// feature function extracts one or more feature type and value pairs from the
// object.
//
// The feature extractor has a modular design where new feature functions can be
// registered as components. The feature extractor is initialized from a
// descriptor represented by a protocol buffer. The feature extractor can also
// be initialized from a text-based source specification of the feature
// extractor. Feature specification parsers can be added as components. By
// default, the feature extractor can be read from an ASCII protocol buffer or
// from a simple feature modeling language (FML) specification.

// A feature function is invoked with a focus. Nested feature functions can be
// invoked with another focus determined by the parent feature function.
36 37 #ifndef LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_ 38 #define LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_ 39 40 #include <stddef.h> 41 42 #include <string> 43 #include <vector> 44 45 #include "common/feature-descriptors.h" 46 #include "common/feature-types.h" 47 #include "common/fml-parser.h" 48 #include "common/registry.h" 49 #include "common/task-context.h" 50 #include "common/workspace.h" 51 #include "util/base/integral_types.h" 52 #include "util/base/logging.h" 53 #include "util/base/macros.h" 54 #include "util/gtl/stl_util.h" 55 56 namespace libtextclassifier { 57 namespace nlp_core { 58 59 typedef int64 Predicate; 60 typedef Predicate FeatureValue; 61 62 // A union used to represent discrete and continuous feature values. 63 union FloatFeatureValue { 64 public: FloatFeatureValue(FeatureValue v)65 explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {} FloatFeatureValue(uint32 i,float w)66 FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {} 67 FeatureValue discrete_value; 68 struct { 69 uint32 id; 70 float weight; 71 }; 72 }; 73 74 // A feature vector contains feature type and value pairs. 75 class FeatureVector { 76 public: FeatureVector()77 FeatureVector() {} 78 79 // Adds feature type and value pair to feature vector. add(FeatureType * type,FeatureValue value)80 void add(FeatureType *type, FeatureValue value) { 81 features_.emplace_back(type, value); 82 } 83 84 // Removes all elements from the feature vector. clear()85 void clear() { features_.clear(); } 86 87 // Returns the number of elements in the feature vector. size()88 int size() const { return features_.size(); } 89 90 // Reserves space in the underlying feature vector. reserve(int n)91 void reserve(int n) { features_.reserve(n); } 92 93 // Returns feature type for an element in the feature vector. type(int index)94 FeatureType *type(int index) const { return features_[index].type; } 95 96 // Returns feature value for an element in the feature vector. 
value(int index)97 FeatureValue value(int index) const { return features_[index].value; } 98 99 private: 100 // Structure for holding feature type and value pairs. 101 struct Element { ElementElement102 Element() : type(nullptr), value(-1) {} ElementElement103 Element(FeatureType *t, FeatureValue v) : type(t), value(v) {} 104 105 FeatureType *type; 106 FeatureValue value; 107 }; 108 109 // Array for storing feature vector elements. 110 std::vector<Element> features_; 111 112 TC_DISALLOW_COPY_AND_ASSIGN(FeatureVector); 113 }; 114 115 // The generic feature extractor is the type-independent part of a feature 116 // extractor. This holds the descriptor for the feature extractor and the 117 // collection of feature types used in the feature extractor. The feature 118 // types are not available until FeatureExtractor<>::Init() has been called. 119 class GenericFeatureExtractor { 120 public: 121 GenericFeatureExtractor(); 122 virtual ~GenericFeatureExtractor(); 123 124 // Initializes the feature extractor from an FML string specification. For 125 // the FML specification grammar, see fml-parser.h. 126 // 127 // Returns true on success, false on syntax error. 128 bool Parse(const std::string &source); 129 130 // Returns the feature extractor descriptor. descriptor()131 const FeatureExtractorDescriptor &descriptor() const { return descriptor_; } mutable_descriptor()132 FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; } 133 134 // Returns the number of feature types in the feature extractor. Invalid 135 // before Init() has been called. feature_types()136 int feature_types() const { return feature_types_.size(); } 137 138 // Returns a feature type used in the extractor. Invalid before Init() has 139 // been called. feature_type(int index)140 const FeatureType *feature_type(int index) const { 141 return feature_types_[index]; 142 } 143 144 // Returns the feature domain size of this feature extractor. 
145 // NOTE: The way that domain size is calculated is, for some, unintuitive. It 146 // is the largest domain size of any feature type. 147 FeatureValue GetDomainSize() const; 148 149 protected: 150 // Initializes the feature types used by the extractor. Called from 151 // FeatureExtractor<>::Init(). 152 // 153 // Returns true on success, false on error. 154 bool InitializeFeatureTypes(); 155 156 private: 157 // Initializes the top-level feature functions. 158 virtual bool InitializeFeatureFunctions() = 0; 159 160 // Returns all feature types used by the extractor. The feature types are 161 // added to the result array. 162 virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0; 163 164 // Descriptor for the feature extractor. This is a protocol buffer that 165 // contains all the information about the feature extractor. The feature 166 // functions are initialized from the information in the descriptor. 167 FeatureExtractorDescriptor descriptor_; 168 169 // All feature types used by the feature extractor. The collection of all the 170 // feature types describes the feature space of the feature set produced by 171 // the feature extractor. Not owned. 172 std::vector<FeatureType *> feature_types_; 173 174 TC_DISALLOW_COPY_AND_ASSIGN(GenericFeatureExtractor); 175 }; 176 177 // The generic feature function is the type-independent part of a feature 178 // function. Each feature function is associated with the descriptor that it is 179 // instantiated from. The feature types associated with this feature function 180 // will be established by the time FeatureExtractor<>::Init() completes. 181 class GenericFeatureFunction { 182 public: 183 // A feature value that represents the absence of a value. 184 static constexpr FeatureValue kNone = -1; 185 186 GenericFeatureFunction(); 187 virtual ~GenericFeatureFunction(); 188 189 // Sets up the feature function. NB: FeatureTypes of nested functions are not 190 // guaranteed to be available until Init(). 
191 // 192 // Returns true on success, false on error. Setup(TaskContext * context)193 virtual bool Setup(TaskContext *context) { return true; } 194 195 // Initializes the feature function. NB: The FeatureType of this function must 196 // be established when this method completes. 197 // 198 // Returns true on success, false on error. Init(TaskContext * context)199 virtual bool Init(TaskContext *context) { return true; } 200 201 // Requests workspaces from a registry to obtain indices into a WorkspaceSet 202 // for any Workspace objects used by this feature function. NB: This will be 203 // called after Init(), so it can depend on resources and arguments. RequestWorkspaces(WorkspaceRegistry * registry)204 virtual void RequestWorkspaces(WorkspaceRegistry *registry) {} 205 206 // Appends the feature types produced by the feature function to types. The 207 // default implementation appends feature_type(), if non-null. Invalid 208 // before Init() has been called. 209 virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const; 210 211 // Returns the feature type for feature produced by this feature function. If 212 // the feature function produces features of different types this returns 213 // null. Invalid before Init() has been called. 214 virtual FeatureType *GetFeatureType() const; 215 216 // Returns the name of the registry used for creating the feature function. 217 // This can be used for checking if two feature functions are of the same 218 // kind. 219 virtual const char *RegistryName() const = 0; 220 221 // Returns the value of a named parameter from the feature function 222 // descriptor. Returns empty string ("") if parameter is not found. 223 std::string GetParameter(const std::string &name) const; 224 225 // Returns the int value of a named parameter from the feature function 226 // descriptor. Returns default_value if the parameter is not found or if its 227 // value can't be parsed as an int. 
228 int GetIntParameter(const std::string &name, int default_value) const; 229 230 // Returns the bool value of a named parameter from the feature function 231 // descriptor. Returns default_value if the parameter is not found or if its 232 // value is not "true" or "false". 233 bool GetBoolParameter(const std::string &name, bool default_value) const; 234 235 // Returns the FML function description for the feature function, i.e. the 236 // name and parameters without the nested features. FunctionName()237 std::string FunctionName() const { 238 std::string output; 239 ToFMLFunction(*descriptor_, &output); 240 return output; 241 } 242 243 // Returns the prefix for nested feature functions. This is the prefix of this 244 // feature function concatenated with the feature function name. SubPrefix()245 std::string SubPrefix() const { 246 return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName(); 247 } 248 249 // Returns/sets the feature extractor this function belongs to. extractor()250 GenericFeatureExtractor *extractor() const { return extractor_; } set_extractor(GenericFeatureExtractor * extractor)251 void set_extractor(GenericFeatureExtractor *extractor) { 252 extractor_ = extractor; 253 } 254 255 // Returns/sets the feature function descriptor. descriptor()256 FeatureFunctionDescriptor *descriptor() const { return descriptor_; } set_descriptor(FeatureFunctionDescriptor * descriptor)257 void set_descriptor(FeatureFunctionDescriptor *descriptor) { 258 descriptor_ = descriptor; 259 } 260 261 // Returns a descriptive name for the feature function. The name is taken from 262 // the descriptor for the feature function. If the name is empty or the 263 // feature function is a variable the name is the FML representation of the 264 // feature, including the prefix. 265 std::string name() const; 266 267 // Returns the argument from the feature function descriptor. It defaults to 268 // 0 if the argument has not been specified. 
argument()269 int argument() const { 270 return descriptor_->has_argument() ? descriptor_->argument() : 0; 271 } 272 273 // Returns/sets/clears function name prefix. prefix()274 const std::string &prefix() const { return prefix_; } set_prefix(const std::string & prefix)275 void set_prefix(const std::string &prefix) { prefix_ = prefix; } 276 277 protected: 278 // Returns the feature type for single-type feature functions. feature_type()279 FeatureType *feature_type() const { return feature_type_; } 280 281 // Sets the feature type for single-type feature functions. This takes 282 // ownership of feature_type. Can only be called once with a non-null 283 // pointer. set_feature_type(FeatureType * feature_type)284 void set_feature_type(FeatureType *feature_type) { 285 TC_DCHECK_NE(feature_type, nullptr); 286 feature_type_ = feature_type; 287 } 288 289 private: 290 // Feature extractor this feature function belongs to. Not owned. 291 GenericFeatureExtractor *extractor_ = nullptr; 292 293 // Descriptor for feature function. Not owned. 294 FeatureFunctionDescriptor *descriptor_ = nullptr; 295 296 // Feature type for features produced by this feature function. If the 297 // feature function produces features of multiple feature types this is null 298 // and the feature function must return it's feature types in 299 // GetFeatureTypes(). Owned. 300 FeatureType *feature_type_ = nullptr; 301 302 // Prefix used for sub-feature types of this function. 303 std::string prefix_; 304 }; 305 306 // Feature function that can extract features from an object. Templated on 307 // two type arguments: 308 // 309 // OBJ: The "object" from which features are extracted; e.g., a sentence. This 310 // should be a plain type, rather than a reference or pointer. 311 // 312 // ARGS: A set of 0 or more types that are used to "index" into some part of the 313 // object that should be extracted, e.g. an int token index for a sentence 314 // object. This should not be a reference type. 
315 template <class OBJ, class... ARGS> 316 class FeatureFunction 317 : public GenericFeatureFunction, 318 public RegisterableClass<FeatureFunction<OBJ, ARGS...> > { 319 public: 320 using Self = FeatureFunction<OBJ, ARGS...>; 321 322 // Preprocesses the object. This will be called prior to calling Evaluate() 323 // or Compute() on that object. Preprocess(WorkspaceSet * workspaces,OBJ * object)324 virtual void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {} 325 326 // Appends features computed from the object and focus to the result. The 327 // default implementation delegates to Compute(), adding a single value if 328 // available. Multi-valued feature functions must override this method. Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)329 virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, 330 ARGS... args, FeatureVector *result) const { 331 FeatureValue value = Compute(workspaces, object, args..., result); 332 if (value != kNone) result->add(feature_type(), value); 333 } 334 335 // Returns a feature value computed from the object and focus, or kNone if no 336 // value is computed. Single-valued feature functions only need to override 337 // this method. Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,const FeatureVector * fv)338 virtual FeatureValue Compute(const WorkspaceSet &workspaces, 339 const OBJ &object, ARGS... args, 340 const FeatureVector *fv) const { 341 return kNone; 342 } 343 344 // Instantiates a new feature function in a feature extractor from a feature 345 // descriptor. 
Instantiate(GenericFeatureExtractor * extractor,FeatureFunctionDescriptor * fd,const std::string & prefix)346 static Self *Instantiate(GenericFeatureExtractor *extractor, 347 FeatureFunctionDescriptor *fd, 348 const std::string &prefix) { 349 Self *f = Self::Create(fd->type()); 350 if (f != nullptr) { 351 f->set_extractor(extractor); 352 f->set_descriptor(fd); 353 f->set_prefix(prefix); 354 } 355 return f; 356 } 357 358 // Returns the name of the registry for the feature function. RegistryName()359 const char *RegistryName() const override { return Self::registry()->name(); } 360 361 private: 362 // Special feature function class for resolving variable references. The type 363 // of the feature function is used for resolving the variable reference. When 364 // evaluated it will either get the feature value(s) from the variable portion 365 // of the feature vector, if present, or otherwise it will call the referenced 366 // feature extractor function directly to extract the feature(s). 367 class Reference; 368 }; 369 370 // Base class for features with nested feature functions. The nested functions 371 // are of type NES, which may be different from the type of the parent function. 372 // NB: NestedFeatureFunction will ensure that all initialization of nested 373 // functions takes place during Setup() and Init() -- after the nested features 374 // are initialized, the parent feature is initialized via SetupNested() and 375 // InitNested(). Alternatively, a derived classes that overrides Setup() and 376 // Init() directly should call Parent::Setup(), Parent::Init(), etc. first. 377 // 378 // Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or 379 // Compute, since the nested functions may be of a different type. 380 template <class NES, class OBJ, class... 
ARGS> 381 class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> { 382 public: 383 using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>; 384 385 // Clean up nested functions. ~NestedFeatureFunction()386 ~NestedFeatureFunction() override { 387 // Fully qualified class name, to avoid an ambiguity error when building for 388 // Android. 389 ::libtextclassifier::STLDeleteElements(&nested_); 390 } 391 392 // By default, just appends the nested feature types. GetFeatureTypes(std::vector<FeatureType * > * types)393 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 394 // It's odd if a NestedFeatureFunction does not have anything nested inside 395 // it, so we crash in debug mode. Still, nothing should crash in prod mode. 396 TC_DCHECK(!this->nested().empty()) 397 << "Nested features require nested features to be defined."; 398 for (auto *function : nested_) function->GetFeatureTypes(types); 399 } 400 401 // Sets up the nested features. Setup(TaskContext * context)402 bool Setup(TaskContext *context) override { 403 bool success = CreateNested(this->extractor(), this->descriptor(), &nested_, 404 this->SubPrefix()); 405 if (!success) { 406 return false; 407 } 408 for (auto *function : nested_) { 409 if (!function->Setup(context)) return false; 410 } 411 if (!SetupNested(context)) { 412 return false; 413 } 414 return true; 415 } 416 417 // Sets up this NestedFeatureFunction specifically. SetupNested(TaskContext * context)418 virtual bool SetupNested(TaskContext *context) { return true; } 419 420 // Initializes the nested features. Init(TaskContext * context)421 bool Init(TaskContext *context) override { 422 for (auto *function : nested_) { 423 if (!function->Init(context)) return false; 424 } 425 if (!InitNested(context)) return false; 426 return true; 427 } 428 429 // Initializes this NestedFeatureFunction specifically. 
InitNested(TaskContext * context)430 virtual bool InitNested(TaskContext *context) { return true; } 431 432 // Gets all the workspaces needed for the nested functions. RequestWorkspaces(WorkspaceRegistry * registry)433 void RequestWorkspaces(WorkspaceRegistry *registry) override { 434 for (auto *function : nested_) function->RequestWorkspaces(registry); 435 } 436 437 // Returns the list of nested feature functions. nested()438 const std::vector<NES *> &nested() const { return nested_; } 439 440 // Instantiates nested feature functions for a feature function. Creates and 441 // initializes one feature function for each sub-descriptor in the feature 442 // descriptor. CreateNested(GenericFeatureExtractor * extractor,FeatureFunctionDescriptor * fd,std::vector<NES * > * functions,const std::string & prefix)443 static bool CreateNested(GenericFeatureExtractor *extractor, 444 FeatureFunctionDescriptor *fd, 445 std::vector<NES *> *functions, 446 const std::string &prefix) { 447 for (int i = 0; i < fd->feature_size(); ++i) { 448 FeatureFunctionDescriptor *sub = fd->mutable_feature(i); 449 NES *f = NES::Instantiate(extractor, sub, prefix); 450 if (f == nullptr) { 451 return false; 452 } 453 functions->push_back(f); 454 } 455 return true; 456 } 457 458 protected: 459 // The nested feature functions, if any, in order of declaration in the 460 // feature descriptor. Owned. 461 std::vector<NES *> nested_; 462 }; 463 464 // Base class for a nested feature function that takes nested features with the 465 // same signature as these features, i.e. a meta feature. For this class, we can 466 // provide preprocessing of the nested features. 467 template <class OBJ, class... ARGS> 468 class MetaFeatureFunction 469 : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ, 470 ARGS...> { 471 public: 472 // Preprocesses using the nested features. 
Preprocess(WorkspaceSet * workspaces,OBJ * object)473 void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override { 474 for (auto *function : this->nested_) { 475 function->Preprocess(workspaces, object); 476 } 477 } 478 }; 479 480 // Template for a special type of locator: The locator of type 481 // FeatureFunction<OBJ, ARGS...> calls nested functions of type 482 // FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is 483 // responsible for translating by providing the following: 484 // 485 // // Gets the new additional focus. 486 // IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object); 487 // 488 // This is useful to e.g. add a token focus to a parser state based on some 489 // desired property of that state. 490 template <class DER, class OBJ, class IDX, class... ARGS> 491 class FeatureAddFocusLocator 492 : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ, 493 ARGS...> { 494 public: Preprocess(WorkspaceSet * workspaces,OBJ * object)495 void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override { 496 for (auto *function : this->nested_) { 497 function->Preprocess(workspaces, object); 498 } 499 } 500 Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)501 void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args, 502 FeatureVector *result) const override { 503 IDX focus = 504 static_cast<const DER *>(this)->GetFocus(workspaces, object, args...); 505 for (auto *function : this->nested()) { 506 function->Evaluate(workspaces, object, focus, args..., result); 507 } 508 } 509 510 // Returns the first nested feature's computed value. Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,const FeatureVector * result)511 FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object, 512 ARGS... 
args, 513 const FeatureVector *result) const override { 514 IDX focus = 515 static_cast<const DER *>(this)->GetFocus(workspaces, object, args...); 516 return this->nested()[0]->Compute(workspaces, object, focus, args..., 517 result); 518 } 519 }; 520 521 // CRTP feature locator class. This is a meta feature that modifies ARGS and 522 // then calls the nested feature functions with the modified ARGS. Note that in 523 // order for this template to work correctly, all of ARGS must be types for 524 // which the reference operator & can be interpreted as a pointer to the 525 // argument. The derived class DER must implement the UpdateFocus method which 526 // takes pointers to the ARGS arguments: 527 // 528 // // Updates the current arguments. 529 // void UpdateArgs(const OBJ &object, ARGS *...args) const; 530 template <class DER, class OBJ, class... ARGS> 531 class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> { 532 public: 533 // Feature locators have an additional check that there is no intrinsic type, 534 // but only in debug mode: having an intrinsic type here is odd, but not 535 // enough to motive a crash in prod. GetFeatureTypes(std::vector<FeatureType * > * types)536 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 537 TC_DCHECK_EQ(this->feature_type(), nullptr) 538 << "FeatureLocators should not have an intrinsic type."; 539 MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types); 540 } 541 542 // Evaluates the locator. Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)543 void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args, 544 FeatureVector *result) const override { 545 static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...); 546 for (auto *function : this->nested()) { 547 function->Evaluate(workspaces, object, args..., result); 548 } 549 } 550 551 // Returns the first nested feature's computed value. 
Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,const FeatureVector * result)552 FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object, 553 ARGS... args, 554 const FeatureVector *result) const override { 555 static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...); 556 return this->nested()[0]->Compute(workspaces, object, args..., result); 557 } 558 }; 559 560 // Feature extractor for extracting features from objects of a certain class. 561 // Template type parameters are as defined for FeatureFunction. 562 template <class OBJ, class... ARGS> 563 class FeatureExtractor : public GenericFeatureExtractor { 564 public: 565 // Feature function type for top-level functions in the feature extractor. 566 typedef FeatureFunction<OBJ, ARGS...> Function; 567 typedef FeatureExtractor<OBJ, ARGS...> Self; 568 569 // Feature locator type for the feature extractor. 570 template <class DER> 571 using Locator = FeatureLocator<DER, OBJ, ARGS...>; 572 573 // Initializes feature extractor. FeatureExtractor()574 FeatureExtractor() {} 575 ~FeatureExtractor()576 ~FeatureExtractor() override { 577 // Fully qualified class name, to avoid an ambiguity error when building for 578 // Android. 579 ::libtextclassifier::STLDeleteElements(&functions_); 580 } 581 582 // Sets up the feature extractor. Note that only top-level functions exist 583 // until Setup() is called. This does not take ownership over the context, 584 // which must outlive this. Setup(TaskContext * context)585 bool Setup(TaskContext *context) { 586 for (Function *function : functions_) { 587 if (!function->Setup(context)) return false; 588 } 589 return true; 590 } 591 592 // Initializes the feature extractor. Must be called after Setup(). This 593 // does not take ownership over the context, which must outlive this. 
Init(TaskContext * context)594 bool Init(TaskContext *context) { 595 for (Function *function : functions_) { 596 if (!function->Init(context)) return false; 597 } 598 if (!this->InitializeFeatureTypes()) { 599 return false; 600 } 601 return true; 602 } 603 604 // Requests workspaces from the registry. Must be called after Init(), and 605 // before Preprocess(). Does not take ownership over registry. This should be 606 // the same registry used to initialize the WorkspaceSet used in Preprocess() 607 // and ExtractFeatures(). NB: This is a different ordering from that used in 608 // SentenceFeatureRepresentation style feature computation. RequestWorkspaces(WorkspaceRegistry * registry)609 void RequestWorkspaces(WorkspaceRegistry *registry) { 610 for (auto *function : functions_) function->RequestWorkspaces(registry); 611 } 612 613 // Preprocesses the object using feature functions for the phase. Must be 614 // called before any calls to ExtractFeatures() on that object and phase. Preprocess(WorkspaceSet * workspaces,OBJ * object)615 void Preprocess(WorkspaceSet *workspaces, OBJ *object) const { 616 for (Function *function : functions_) { 617 function->Preprocess(workspaces, object); 618 } 619 } 620 621 // Extracts features from an object with a focus. This invokes all the 622 // top-level feature functions in the feature extractor. Only feature 623 // functions belonging to the specified phase are invoked. ExtractFeatures(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)624 void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object, 625 ARGS... args, FeatureVector *result) const { 626 result->reserve(this->feature_types()); 627 628 // Extract features. 629 for (int i = 0; i < functions_.size(); ++i) { 630 functions_[i]->Evaluate(workspaces, object, args..., result); 631 } 632 } 633 634 private: 635 // Creates and initializes all feature functions in the feature extractor. 
InitializeFeatureFunctions()636 bool InitializeFeatureFunctions() override { 637 // Create all top-level feature functions. 638 for (int i = 0; i < descriptor().feature_size(); ++i) { 639 FeatureFunctionDescriptor *fd = mutable_descriptor()->mutable_feature(i); 640 Function *function = Function::Instantiate(this, fd, ""); 641 if (function == nullptr) return false; 642 functions_.push_back(function); 643 } 644 return true; 645 } 646 647 // Collect all feature types used in the feature extractor. GetFeatureTypes(std::vector<FeatureType * > * types)648 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 649 for (Function *function : functions_) { 650 function->GetFeatureTypes(types); 651 } 652 } 653 654 // Top-level feature functions (and variables) in the feature extractor. 655 // Owned. INVARIANT: contains only non-null pointers. 656 std::vector<Function *> functions_; 657 }; 658 659 #define REGISTER_FEATURE_FUNCTION(base, name, component) \ 660 REGISTER_CLASS_COMPONENT(base, name, component) 661 662 } // namespace nlp_core 663 } // namespace libtextclassifier 664 665 #endif // LIBTEXTCLASSIFIER_COMMON_FEATURE_EXTRACTOR_H_ 666