/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Generic feature extractor for extracting features from objects.  The feature
// extractor can be used for extracting features from any object.  The feature
// extractor and feature function classes are template classes that have to
// be instantiated for extracting features from a specific object type.
//
// A feature extractor consists of a hierarchy of feature functions.  Each
// feature function extracts one or more feature type and value pairs from the
// object.
//
// The feature extractor has a modular design where new feature functions can be
// registered as components.  The feature extractor is initialized from a
// descriptor represented by a protocol buffer.  The feature extractor can also
// be initialized from a text-based source specification of the feature
// extractor.  Feature specification parsers can be added as components.  By
// default the feature extractor can be read from an ASCII protocol buffer or in
// a simple feature modeling language (fml).

// A feature function is invoked with a focus.  Nested feature functions can be
// invoked with another focus determined by the parent feature function.
36 37 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_ 38 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_ 39 40 #include <stddef.h> 41 42 #include <string> 43 #include <vector> 44 45 #include "lang_id/common/fel/feature-descriptors.h" 46 #include "lang_id/common/fel/feature-types.h" 47 #include "lang_id/common/fel/task-context.h" 48 #include "lang_id/common/fel/workspace.h" 49 #include "lang_id/common/lite_base/attributes.h" 50 #include "lang_id/common/lite_base/integral-types.h" 51 #include "lang_id/common/lite_base/logging.h" 52 #include "lang_id/common/lite_base/macros.h" 53 #include "lang_id/common/registry.h" 54 #include "lang_id/common/stl-util.h" 55 56 namespace libtextclassifier3 { 57 namespace mobile { 58 59 // TODO(djweiss) Clean this up as well. 60 // Use the same type for feature values as is used for predicated. 61 typedef int64 Predicate; 62 typedef Predicate FeatureValue; 63 64 // A union used to represent discrete and continuous feature values. 65 union FloatFeatureValue { 66 public: FloatFeatureValue(FeatureValue v)67 explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {} FloatFeatureValue(uint32 i,float w)68 FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {} 69 FeatureValue discrete_value; 70 struct { 71 uint32 id; 72 float weight; 73 }; 74 }; 75 76 // A feature vector contains feature type and value pairs. 77 class FeatureVector { 78 public: FeatureVector()79 FeatureVector() {} 80 81 // Adds feature type and value pair to feature vector. add(FeatureType * type,FeatureValue value)82 void add(FeatureType *type, FeatureValue value) { 83 features_.emplace_back(type, value); 84 } 85 86 // Removes all elements from the feature vector. clear()87 void clear() { features_.clear(); } 88 89 // Returns the number of elements in the feature vector. size()90 int size() const { return features_.size(); } 91 92 // Reserves space in the underlying feature vector. 
reserve(int n)93 void reserve(int n) { features_.reserve(n); } 94 95 // Returns feature type for an element in the feature vector. type(int index)96 FeatureType *type(int index) const { return features_[index].type; } 97 98 // Returns feature value for an element in the feature vector. value(int index)99 FeatureValue value(int index) const { return features_[index].value; } 100 101 private: 102 // Structure for holding feature type and value pairs. 103 struct Element { ElementElement104 Element() : type(nullptr), value(-1) {} ElementElement105 Element(FeatureType *t, FeatureValue v) : type(t), value(v) {} 106 107 FeatureType *type; 108 FeatureValue value; 109 }; 110 111 // Array for storing feature vector elements. 112 std::vector<Element> features_; 113 114 SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureVector); 115 }; 116 117 // The generic feature extractor is the type-independent part of a feature 118 // extractor. This holds the descriptor for the feature extractor and the 119 // collection of feature types used in the feature extractor. The feature 120 // types are not available until FeatureExtractor<>::Init() has been called. 121 class GenericFeatureExtractor { 122 public: 123 GenericFeatureExtractor(); 124 virtual ~GenericFeatureExtractor(); 125 126 // Initializes the feature extractor from the FEL specification |source|. 127 // 128 // Returns true on success, false otherwise (e.g., FEL syntax error). 129 SAFTM_MUST_USE_RESULT bool Parse(const string &source); 130 131 // Returns the feature extractor descriptor. descriptor()132 const FeatureExtractorDescriptor &descriptor() const { return descriptor_; } mutable_descriptor()133 FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; } 134 135 // Returns the number of feature types in the feature extractor. Invalid 136 // before Init() has been called. 
feature_types()137 int feature_types() const { return feature_types_.size(); } 138 139 protected: 140 // Initializes the feature types used by the extractor. Called from 141 // FeatureExtractor<>::Init(). 142 // 143 // Returns true on success, false otherwise. 144 SAFTM_MUST_USE_RESULT bool InitializeFeatureTypes(); 145 146 private: 147 // Initializes the top-level feature functions. 148 // 149 // Returns true on success, false otherwise. 150 SAFTM_MUST_USE_RESULT virtual bool InitializeFeatureFunctions() = 0; 151 152 // Returns all feature types used by the extractor. The feature types are 153 // added to the result array. 154 virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0; 155 156 // Descriptor for the feature extractor. This is a protocol buffer that 157 // contains all the information about the feature extractor. The feature 158 // functions are initialized from the information in the descriptor. 159 FeatureExtractorDescriptor descriptor_; 160 161 // All feature types used by the feature extractor. The collection of all the 162 // feature types describes the feature space of the feature set produced by 163 // the feature extractor. Not owned. 164 std::vector<FeatureType *> feature_types_; 165 }; 166 167 // The generic feature function is the type-independent part of a feature 168 // function. Each feature function is associated with the descriptor that it is 169 // instantiated from. The feature types associated with this feature function 170 // will be established by the time FeatureExtractor<>::Init() completes. 171 class GenericFeatureFunction { 172 public: 173 // A feature value that represents the absence of a value. 174 static constexpr FeatureValue kNone = -1; 175 176 GenericFeatureFunction(); 177 virtual ~GenericFeatureFunction(); 178 179 // Sets up the feature function. NB: FeatureTypes of nested functions are not 180 // guaranteed to be available until Init(). 181 // 182 // Returns true on success, false otherwise. 
Setup(TaskContext * context)183 SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context) { 184 return true; 185 } 186 187 // Initializes the feature function. NB: The FeatureType of this function must 188 // be established when this method completes. 189 // 190 // Returns true on success, false otherwise. Init(TaskContext * context)191 SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context) { return true; } 192 193 // Requests workspaces from a registry to obtain indices into a WorkspaceSet 194 // for any Workspace objects used by this feature function. NB: This will be 195 // called after Init(), so it can depend on resources and arguments. RequestWorkspaces(WorkspaceRegistry * registry)196 virtual void RequestWorkspaces(WorkspaceRegistry *registry) {} 197 198 // Appends the feature types produced by the feature function to types. The 199 // default implementation appends feature_type(), if non-null. Invalid 200 // before Init() has been called. 201 virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const; 202 203 // Returns the feature type for feature produced by this feature function. If 204 // the feature function produces features of different types this returns 205 // null. Invalid before Init() has been called. 206 virtual FeatureType *GetFeatureType() const; 207 208 // Returns value of parameter |name| from the feature function descriptor. 209 // If the parameter is not present, returns the indicated |default_value|. 210 string GetParameter(const string &name, const string &default_value) const; 211 212 // Returns value of int parameter |name| from feature function descriptor. 213 // If the parameter is not present, or its value can't be parsed as an int, 214 // returns |default_value|. 215 int GetIntParameter(const string &name, int default_value) const; 216 217 // Returns value of bool parameter |name| from feature function descriptor. 
218 // If the parameter is not present, or its value is not "true" or "false", 219 // returns |default_value|. NOTE: this method is case sensitive, it doesn't 220 // do any lower-casing. 221 bool GetBoolParameter(const string &name, bool default_value) const; 222 223 // Returns the FEL function description for the feature function, i.e. the 224 // name and parameters without the nested features. FunctionName()225 string FunctionName() const { 226 string output; 227 ToFELFunction(*descriptor_, &output); 228 return output; 229 } 230 231 // Returns the prefix for nested feature functions. This is the prefix of this 232 // feature function concatenated with the feature function name. SubPrefix()233 string SubPrefix() const { 234 return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName(); 235 } 236 237 // Returns/sets the feature extractor this function belongs to. extractor()238 const GenericFeatureExtractor *extractor() const { return extractor_; } set_extractor(const GenericFeatureExtractor * extractor)239 void set_extractor(const GenericFeatureExtractor *extractor) { 240 extractor_ = extractor; 241 } 242 243 // Returns/sets the feature function descriptor. descriptor()244 const FeatureFunctionDescriptor *descriptor() const { return descriptor_; } set_descriptor(const FeatureFunctionDescriptor * descriptor)245 void set_descriptor(const FeatureFunctionDescriptor *descriptor) { 246 descriptor_ = descriptor; 247 } 248 249 // Returns a descriptive name for the feature function. The name is taken from 250 // the descriptor for the feature function. If the name is empty or the 251 // feature function is a variable the name is the FEL representation of the 252 // feature, including the prefix. 253 string name() const; 254 255 // Returns the argument from the feature function descriptor. It defaults to 256 // 0 if the argument has not been specified. argument()257 int argument() const { 258 return descriptor_->has_argument() ? 
descriptor_->argument() : 0; 259 } 260 261 // Returns/sets/clears function name prefix. prefix()262 const string &prefix() const { return prefix_; } set_prefix(const string & prefix)263 void set_prefix(const string &prefix) { prefix_ = prefix; } 264 265 protected: 266 // Returns the feature type for single-type feature functions. feature_type()267 FeatureType *feature_type() const { return feature_type_; } 268 269 // Sets the feature type for single-type feature functions. This takes 270 // ownership of feature_type. Can only be called once. set_feature_type(FeatureType * feature_type)271 void set_feature_type(FeatureType *feature_type) { 272 SAFTM_CHECK_EQ(feature_type_, nullptr); 273 feature_type_ = feature_type; 274 } 275 276 private: 277 // Feature extractor this feature function belongs to. Not owned. Set to a 278 // pointer != nullptr as soon as this object is created by Instantiate(). 279 // Normal methods can safely assume this is != nullptr. 280 const GenericFeatureExtractor *extractor_ = nullptr; 281 282 // Descriptor for feature function. Not owned. Set to a pointer != nullptr 283 // as soon as this object is created by Instantiate(). Normal methods can 284 // safely assume this is != nullptr. 285 const FeatureFunctionDescriptor *descriptor_ = nullptr; 286 287 // Feature type for features produced by this feature function. If the 288 // feature function produces features of multiple feature types this is null 289 // and the feature function must return it's feature types in 290 // GetFeatureTypes(). Owned. 291 FeatureType *feature_type_ = nullptr; 292 293 // Prefix used for sub-feature types of this function. 294 string prefix_; 295 }; 296 297 // Feature function that can extract features from an object. Templated on 298 // two type arguments: 299 // 300 // OBJ: The "object" from which features are extracted; e.g., a sentence. This 301 // should be a plain type, rather than a reference or pointer. 
302 // 303 // ARGS: A set of 0 or more types that are used to "index" into some part of the 304 // object that should be extracted, e.g. an int token index for a sentence 305 // object. This should not be a reference type. 306 template <class OBJ, class... ARGS> 307 class FeatureFunction 308 : public GenericFeatureFunction, 309 public RegisterableClass<FeatureFunction<OBJ, ARGS...> > { 310 public: 311 using Self = FeatureFunction<OBJ, ARGS...>; 312 313 // Preprocesses the object. This will be called prior to calling Evaluate() 314 // or Compute() on that object. Preprocess(WorkspaceSet * workspaces,const OBJ * object)315 virtual void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {} 316 317 // Appends features computed from the object and focus to the result. The 318 // default implementation delegates to Compute(), adding a single value if 319 // available. Multi-valued feature functions must override this method. Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)320 virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, 321 ARGS... args, FeatureVector *result) const { 322 FeatureValue value = Compute(workspaces, object, args...); 323 if (value != kNone) result->add(feature_type(), value); 324 } 325 326 // Returns a feature value computed from the object and focus, or kNone if no 327 // value is computed. Single-valued feature functions only need to override 328 // this method. Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args)329 virtual FeatureValue Compute(const WorkspaceSet &workspaces, 330 const OBJ &object, ARGS... args) const { 331 return kNone; 332 } 333 334 // Instantiates a new feature function in a feature extractor from a feature 335 // descriptor. 336 // 337 // Returns a pointer to the newly-created object if everything goes well. 
338 // Returns nullptr if the feature function could not be instantiated (e.g., if 339 // the function with that name is not registered; this usually happens because 340 // the relevant cc_library was not linked-in). Instantiate(const GenericFeatureExtractor * extractor,const FeatureFunctionDescriptor * fd,const string & prefix)341 static Self *Instantiate(const GenericFeatureExtractor *extractor, 342 const FeatureFunctionDescriptor *fd, 343 const string &prefix) { 344 Self *f = Self::Create(fd->type()); 345 if (f != nullptr) { 346 f->set_extractor(extractor); 347 f->set_descriptor(fd); 348 f->set_prefix(prefix); 349 } 350 return f; 351 } 352 353 private: 354 // Special feature function class for resolving variable references. The type 355 // of the feature function is used for resolving the variable reference. When 356 // evaluated it will either get the feature value(s) from the variable portion 357 // of the feature vector, if present, or otherwise it will call the referenced 358 // feature extractor function directly to extract the feature(s). 359 class Reference; 360 }; 361 362 // Base class for features with nested feature functions. The nested functions 363 // are of type NES, which may be different from the type of the parent function. 364 // NB: NestedFeatureFunction will ensure that all initialization of nested 365 // functions takes place during Setup() and Init() -- after the nested features 366 // are initialized, the parent feature is initialized via SetupNested() and 367 // InitNested(). Alternatively, a derived classes that overrides Setup() and 368 // Init() directly should call Parent::Setup(), Parent::Init(), etc. first. 369 // 370 // Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or 371 // Compute, since the nested functions may be of a different type. 372 template <class NES, class OBJ, class... 
ARGS> 373 class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> { 374 public: 375 using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>; 376 377 // Clean up nested functions. ~NestedFeatureFunction()378 ~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); } 379 380 // By default, just appends the nested feature types. GetFeatureTypes(std::vector<FeatureType * > * types)381 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 382 SAFTM_CHECK(!this->nested().empty()) 383 << "Nested features require nested features to be defined."; 384 for (auto *function : nested_) function->GetFeatureTypes(types); 385 } 386 387 // Sets up the nested features. 388 // 389 // Returns true on success, false otherwise. Setup(TaskContext * context)390 SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override { 391 bool success = CreateNested(this->extractor(), this->descriptor(), &nested_, 392 this->SubPrefix()); 393 if (!success) return false; 394 for (auto *function : nested_) { 395 if (!function->Setup(context)) return false; 396 } 397 if (!SetupNested(context)) return false; 398 return true; 399 } 400 401 // Sets up this NestedFeatureFunction specifically. 402 // 403 // Returns true on success, false otherwise. SetupNested(TaskContext * context)404 SAFTM_MUST_USE_RESULT virtual bool SetupNested(TaskContext *context) { 405 return true; 406 } 407 408 // Initializes the nested features. 409 // 410 // Returns true on success, false otherwise. Init(TaskContext * context)411 SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override { 412 for (auto *function : nested_) { 413 if (!function->Init(context)) return false; 414 } 415 if (!InitNested(context)) return false; 416 return true; 417 } 418 419 // Initializes this NestedFeatureFunction specifically. 420 // 421 // Returns true on success, false otherwise. 
InitNested(TaskContext * context)422 SAFTM_MUST_USE_RESULT virtual bool InitNested(TaskContext *context) { 423 return true; 424 } 425 426 // Gets all the workspaces needed for the nested functions. RequestWorkspaces(WorkspaceRegistry * registry)427 void RequestWorkspaces(WorkspaceRegistry *registry) override { 428 for (auto *function : nested_) function->RequestWorkspaces(registry); 429 } 430 431 // Returns the list of nested feature functions. nested()432 const std::vector<NES *> &nested() const { return nested_; } 433 434 // Instantiates nested feature functions for a feature function. Creates and 435 // initializes one feature function for each sub-descriptor in the feature 436 // descriptor. 437 // 438 // Returns true on success, false otherwise. CreateNested(const GenericFeatureExtractor * extractor,const FeatureFunctionDescriptor * fd,std::vector<NES * > * functions,const string & prefix)439 SAFTM_MUST_USE_RESULT static bool CreateNested( 440 const GenericFeatureExtractor *extractor, 441 const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions, 442 const string &prefix) { 443 for (int i = 0; i < fd->feature_size(); ++i) { 444 const FeatureFunctionDescriptor &sub = fd->feature(i); 445 NES *f = NES::Instantiate(extractor, &sub, prefix); 446 if (f == nullptr) return false; 447 functions->push_back(f); 448 } 449 return true; 450 } 451 452 protected: 453 // The nested feature functions, if any, in order of declaration in the 454 // feature descriptor. Owned. 455 std::vector<NES *> nested_; 456 }; 457 458 // Base class for a nested feature function that takes nested features with the 459 // same signature as these features, i.e. a meta feature. For this class, we can 460 // provide preprocessing of the nested features. 461 template <class OBJ, class... ARGS> 462 class MetaFeatureFunction 463 : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ, 464 ARGS...> { 465 public: 466 // Preprocesses using the nested features. 
Preprocess(WorkspaceSet * workspaces,const OBJ * object)467 void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override { 468 for (auto *function : this->nested_) { 469 function->Preprocess(workspaces, object); 470 } 471 } 472 }; 473 474 // Template for a special type of locator: The locator of type 475 // FeatureFunction<OBJ, ARGS...> calls nested functions of type 476 // FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is 477 // responsible for translating by providing the following: 478 // 479 // // Gets the new additional focus. 480 // IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object); 481 // 482 // This is useful to e.g. add a token focus to a parser state based on some 483 // desired property of that state. 484 template <class DER, class OBJ, class IDX, class... ARGS> 485 class FeatureAddFocusLocator 486 : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ, 487 ARGS...> { 488 public: Preprocess(WorkspaceSet * workspaces,const OBJ * object)489 void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override { 490 for (auto *function : this->nested_) { 491 function->Preprocess(workspaces, object); 492 } 493 } 494 Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)495 void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args, 496 FeatureVector *result) const override { 497 IDX focus = 498 static_cast<const DER *>(this)->GetFocus(workspaces, object, args...); 499 for (auto *function : this->nested()) { 500 function->Evaluate(workspaces, object, focus, args..., result); 501 } 502 } 503 504 // Returns the first nested feature's computed value. Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args)505 FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object, 506 ARGS... 
args) const override { 507 IDX focus = 508 static_cast<const DER *>(this)->GetFocus(workspaces, object, args...); 509 return this->nested()[0]->Compute(workspaces, object, focus, args...); 510 } 511 }; 512 513 // CRTP feature locator class. This is a meta feature that modifies ARGS and 514 // then calls the nested feature functions with the modified ARGS. Note that in 515 // order for this template to work correctly, all of ARGS must be types for 516 // which the reference operator & can be interpreted as a pointer to the 517 // argument. The derived class DER must implement the UpdateFocus method which 518 // takes pointers to the ARGS arguments: 519 // 520 // // Updates the current arguments. 521 // void UpdateArgs(const OBJ &object, ARGS *...args) const; 522 template <class DER, class OBJ, class... ARGS> 523 class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> { 524 public: 525 // Feature locators have an additional check that there is no intrinsic type. GetFeatureTypes(std::vector<FeatureType * > * types)526 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 527 SAFTM_CHECK_EQ(this->feature_type(), nullptr) 528 << "FeatureLocators should not have an intrinsic type."; 529 MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types); 530 } 531 532 // Evaluates the locator. Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)533 void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args, 534 FeatureVector *result) const override { 535 static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...); 536 for (auto *function : this->nested()) { 537 function->Evaluate(workspaces, object, args..., result); 538 } 539 } 540 541 // Returns the first nested feature's computed value. Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args)542 FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object, 543 ARGS... 
args) const override { 544 static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...); 545 return this->nested()[0]->Compute(workspaces, object, args...); 546 } 547 }; 548 549 // Feature extractor for extracting features from objects of a certain class. 550 // Template type parameters are as defined for FeatureFunction. 551 template <class OBJ, class... ARGS> 552 class FeatureExtractor : public GenericFeatureExtractor { 553 public: 554 // Feature function type for top-level functions in the feature extractor. 555 typedef FeatureFunction<OBJ, ARGS...> Function; 556 typedef FeatureExtractor<OBJ, ARGS...> Self; 557 558 // Feature locator type for the feature extractor. 559 template <class DER> 560 using Locator = FeatureLocator<DER, OBJ, ARGS...>; 561 562 // Initializes feature extractor. FeatureExtractor()563 FeatureExtractor() {} 564 ~FeatureExtractor()565 ~FeatureExtractor() override { utils::STLDeleteElements(&functions_); } 566 567 // Sets up the feature extractor. Note that only top-level functions exist 568 // until Setup() is called. This does not take ownership over the context, 569 // which must outlive this. 570 // 571 // Returns true on success, false otherwise. Setup(TaskContext * context)572 SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) { 573 for (Function *function : functions_) { 574 if (!function->Setup(context)) return false; 575 } 576 return true; 577 } 578 579 // Initializes the feature extractor. Must be called after Setup(). This 580 // does not take ownership over the context, which must outlive this. 581 // 582 // Returns true on success, false otherwise. Init(TaskContext * context)583 SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) { 584 for (Function *function : functions_) { 585 if (!function->Init(context)) return false; 586 } 587 if (!this->InitializeFeatureTypes()) return false; 588 return true; 589 } 590 591 // Requests workspaces from the registry. Must be called after Init(), and 592 // before Preprocess(). 
Does not take ownership over registry. This should be 593 // the same registry used to initialize the WorkspaceSet used in Preprocess() 594 // and ExtractFeatures(). NB: This is a different ordering from that used in 595 // SentenceFeatureRepresentation style feature computation. RequestWorkspaces(WorkspaceRegistry * registry)596 void RequestWorkspaces(WorkspaceRegistry *registry) { 597 for (auto *function : functions_) function->RequestWorkspaces(registry); 598 } 599 600 // Preprocesses the object using feature functions for the phase. Must be 601 // called before any calls to ExtractFeatures() on that object and phase. Preprocess(WorkspaceSet * workspaces,const OBJ * object)602 void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const { 603 for (Function *function : functions_) { 604 function->Preprocess(workspaces, object); 605 } 606 } 607 608 // Extracts features from an object with a focus. This invokes all the 609 // top-level feature functions in the feature extractor. Only feature 610 // functions belonging to the specified phase are invoked. ExtractFeatures(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)611 void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object, 612 ARGS... args, FeatureVector *result) const { 613 result->reserve(this->feature_types()); 614 615 // Extract features. 616 for (int i = 0; i < functions_.size(); ++i) { 617 functions_[i]->Evaluate(workspaces, object, args..., result); 618 } 619 } 620 621 private: 622 // Creates and initializes all feature functions in the feature extractor. 623 // 624 // Returns true on success, false otherwise. InitializeFeatureFunctions()625 SAFTM_MUST_USE_RESULT bool InitializeFeatureFunctions() override { 626 // Create all top-level feature functions. 
627 for (int i = 0; i < descriptor().feature_size(); ++i) { 628 const FeatureFunctionDescriptor &fd = descriptor().feature(i); 629 Function *function = Function::Instantiate(this, &fd, ""); 630 if (function == nullptr) return false; 631 functions_.push_back(function); 632 } 633 return true; 634 } 635 636 // Collect all feature types used in the feature extractor. GetFeatureTypes(std::vector<FeatureType * > * types)637 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 638 for (int i = 0; i < functions_.size(); ++i) { 639 functions_[i]->GetFeatureTypes(types); 640 } 641 } 642 643 // Top-level feature functions (and variables) in the feature extractor. 644 // Owned. 645 std::vector<Function *> functions_; 646 }; 647 648 } // namespace mobile 649 } // namespace nlp_saft 650 651 #endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_ 652