/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Generic feature extractor for extracting features from objects.  The feature
// extractor can be used for extracting features from any object.  The feature
// extractor and feature function classes are template classes that have to
// be instantiated for extracting features from a specific object type.
//
// A feature extractor consists of a hierarchy of feature functions.  Each
// feature function extracts one or more feature type and value pairs from the
// object.
//
// The feature extractor has a modular design where new feature functions can be
// registered as components.  The feature extractor is initialized from a
// descriptor represented by a protocol buffer.  The feature extractor can also
// be initialized from a text-based source specification of the feature
// extractor.  Feature specification parsers can be added as components.  By
// default the feature extractor can be read from an ASCII protocol buffer or in
// a simple feature modeling language (fml).

// A feature function is invoked with a focus.  Nested feature functions can be
// invoked with another focus determined by the parent feature function.
36 37 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_ 38 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_ 39 40 #include <stddef.h> 41 42 #include <string> 43 #include <vector> 44 45 #include "lang_id/common/fel/feature-descriptors.h" 46 #include "lang_id/common/fel/feature-types.h" 47 #include "lang_id/common/fel/task-context.h" 48 #include "lang_id/common/fel/workspace.h" 49 #include "lang_id/common/lite_base/attributes.h" 50 #include "lang_id/common/lite_base/integral-types.h" 51 #include "lang_id/common/lite_base/logging.h" 52 #include "lang_id/common/lite_base/macros.h" 53 #include "lang_id/common/registry.h" 54 #include "lang_id/common/stl-util.h" 55 56 namespace libtextclassifier3 { 57 namespace mobile { 58 59 // TODO(djweiss) Clean this up as well. 60 // Use the same type for feature values as is used for predicated. 61 typedef int64 Predicate; 62 typedef Predicate FeatureValue; 63 64 // A union used to represent discrete and continuous feature values. 65 union FloatFeatureValue { 66 public: FloatFeatureValue(FeatureValue v)67 explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {} FloatFeatureValue(uint32 i,float w)68 FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {} 69 FeatureValue discrete_value; 70 struct { 71 uint32 id; 72 float weight; 73 }; 74 }; 75 76 // A feature vector contains feature type and value pairs. 77 class FeatureVector { 78 public: FeatureVector()79 FeatureVector() {} 80 81 // Adds feature type and value pair to feature vector. add(FeatureType * type,FeatureValue value)82 void add(FeatureType *type, FeatureValue value) { 83 features_.emplace_back(type, value); 84 } 85 86 // Removes all elements from the feature vector. clear()87 void clear() { features_.clear(); } 88 89 // Returns the number of elements in the feature vector. size()90 int size() const { return features_.size(); } 91 92 // Reserves space in the underlying feature vector. 
reserve(int n)93 void reserve(int n) { features_.reserve(n); } 94 95 // Returns feature type for an element in the feature vector. type(int index)96 FeatureType *type(int index) const { return features_[index].type; } 97 98 // Returns feature value for an element in the feature vector. value(int index)99 FeatureValue value(int index) const { return features_[index].value; } 100 101 private: 102 // Structure for holding feature type and value pairs. 103 struct Element { ElementElement104 Element() : type(nullptr), value(-1) {} ElementElement105 Element(FeatureType *t, FeatureValue v) : type(t), value(v) {} 106 107 FeatureType *type; 108 FeatureValue value; 109 }; 110 111 // Array for storing feature vector elements. 112 std::vector<Element> features_; 113 114 SAFTM_DISALLOW_COPY_AND_ASSIGN(FeatureVector); 115 }; 116 117 // The generic feature extractor is the type-independent part of a feature 118 // extractor. This holds the descriptor for the feature extractor and the 119 // collection of feature types used in the feature extractor. The feature 120 // types are not available until FeatureExtractor<>::Init() has been called. 121 class GenericFeatureExtractor { 122 public: 123 GenericFeatureExtractor(); 124 virtual ~GenericFeatureExtractor(); 125 126 // Initializes the feature extractor from the FEL specification |source|. 127 // 128 // Returns true on success, false otherwise (e.g., FEL syntax error). 129 SAFTM_MUST_USE_RESULT bool Parse(const std::string &source); 130 131 // Returns the feature extractor descriptor. descriptor()132 const FeatureExtractorDescriptor &descriptor() const { return descriptor_; } mutable_descriptor()133 FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; } 134 135 // Returns the number of feature types in the feature extractor. Invalid 136 // before Init() has been called. 
feature_types()137 int feature_types() const { return feature_types_.size(); } 138 139 protected: 140 // Initializes the feature types used by the extractor. Called from 141 // FeatureExtractor<>::Init(). 142 // 143 // Returns true on success, false otherwise. 144 SAFTM_MUST_USE_RESULT bool InitializeFeatureTypes(); 145 146 private: 147 // Initializes the top-level feature functions. 148 // 149 // Returns true on success, false otherwise. 150 SAFTM_MUST_USE_RESULT virtual bool InitializeFeatureFunctions() = 0; 151 152 // Returns all feature types used by the extractor. The feature types are 153 // added to the result array. 154 virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0; 155 156 // Descriptor for the feature extractor. This is a protocol buffer that 157 // contains all the information about the feature extractor. The feature 158 // functions are initialized from the information in the descriptor. 159 FeatureExtractorDescriptor descriptor_; 160 161 // All feature types used by the feature extractor. The collection of all the 162 // feature types describes the feature space of the feature set produced by 163 // the feature extractor. Not owned. 164 std::vector<FeatureType *> feature_types_; 165 }; 166 167 // The generic feature function is the type-independent part of a feature 168 // function. Each feature function is associated with the descriptor that it is 169 // instantiated from. The feature types associated with this feature function 170 // will be established by the time FeatureExtractor<>::Init() completes. 171 class GenericFeatureFunction { 172 public: 173 // A feature value that represents the absence of a value. 174 static constexpr FeatureValue kNone = -1; 175 176 GenericFeatureFunction(); 177 virtual ~GenericFeatureFunction(); 178 179 // Sets up the feature function. NB: FeatureTypes of nested functions are not 180 // guaranteed to be available until Init(). 181 // 182 // Returns true on success, false otherwise. 
Setup(TaskContext * context)183 SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context) { 184 return true; 185 } 186 187 // Initializes the feature function. NB: The FeatureType of this function must 188 // be established when this method completes. 189 // 190 // Returns true on success, false otherwise. Init(TaskContext * context)191 SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context) { return true; } 192 193 // Requests workspaces from a registry to obtain indices into a WorkspaceSet 194 // for any Workspace objects used by this feature function. NB: This will be 195 // called after Init(), so it can depend on resources and arguments. RequestWorkspaces(WorkspaceRegistry * registry)196 virtual void RequestWorkspaces(WorkspaceRegistry *registry) {} 197 198 // Appends the feature types produced by the feature function to types. The 199 // default implementation appends feature_type(), if non-null. Invalid 200 // before Init() has been called. 201 virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const; 202 203 // Returns the feature type for feature produced by this feature function. If 204 // the feature function produces features of different types this returns 205 // null. Invalid before Init() has been called. 206 virtual FeatureType *GetFeatureType() const; 207 208 // Returns value of parameter |name| from the feature function descriptor. 209 // If the parameter is not present, returns the indicated |default_value|. 210 std::string GetParameter(const std::string &name, 211 const std::string &default_value) const; 212 213 // Returns value of int parameter |name| from feature function descriptor. 214 // If the parameter is not present, or its value can't be parsed as an int, 215 // returns |default_value|. 216 int GetIntParameter(const std::string &name, int default_value) const; 217 218 // Returns value of bool parameter |name| from feature function descriptor. 
219 // If the parameter is not present, or its value is not "true" or "false", 220 // returns |default_value|. NOTE: this method is case sensitive, it doesn't 221 // do any lower-casing. 222 bool GetBoolParameter(const std::string &name, bool default_value) const; 223 224 // Returns the FEL function description for the feature function, i.e. the 225 // name and parameters without the nested features. FunctionName()226 std::string FunctionName() const { 227 std::string output; 228 ToFELFunction(*descriptor_, &output); 229 return output; 230 } 231 232 // Returns the prefix for nested feature functions. This is the prefix of this 233 // feature function concatenated with the feature function name. SubPrefix()234 std::string SubPrefix() const { 235 return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName(); 236 } 237 238 // Returns/sets the feature extractor this function belongs to. extractor()239 const GenericFeatureExtractor *extractor() const { return extractor_; } set_extractor(const GenericFeatureExtractor * extractor)240 void set_extractor(const GenericFeatureExtractor *extractor) { 241 extractor_ = extractor; 242 } 243 244 // Returns/sets the feature function descriptor. descriptor()245 const FeatureFunctionDescriptor *descriptor() const { return descriptor_; } set_descriptor(const FeatureFunctionDescriptor * descriptor)246 void set_descriptor(const FeatureFunctionDescriptor *descriptor) { 247 descriptor_ = descriptor; 248 } 249 250 // Returns a descriptive name for the feature function. The name is taken from 251 // the descriptor for the feature function. If the name is empty or the 252 // feature function is a variable the name is the FEL representation of the 253 // feature, including the prefix. 254 std::string name() const; 255 256 // Returns the argument from the feature function descriptor. It defaults to 257 // 0 if the argument has not been specified. argument()258 int argument() const { 259 return descriptor_->has_argument() ? 
descriptor_->argument() : 0; 260 } 261 262 // Returns/sets/clears function name prefix. prefix()263 const std::string &prefix() const { return prefix_; } set_prefix(const std::string & prefix)264 void set_prefix(const std::string &prefix) { prefix_ = prefix; } 265 266 protected: 267 // Returns the feature type for single-type feature functions. feature_type()268 FeatureType *feature_type() const { return feature_type_; } 269 270 // Sets the feature type for single-type feature functions. This takes 271 // ownership of feature_type. Can only be called once. set_feature_type(FeatureType * feature_type)272 void set_feature_type(FeatureType *feature_type) { 273 SAFTM_CHECK_EQ(feature_type_, nullptr); 274 feature_type_ = feature_type; 275 } 276 277 private: 278 // Feature extractor this feature function belongs to. Not owned. Set to a 279 // pointer != nullptr as soon as this object is created by Instantiate(). 280 // Normal methods can safely assume this is != nullptr. 281 const GenericFeatureExtractor *extractor_ = nullptr; 282 283 // Descriptor for feature function. Not owned. Set to a pointer != nullptr 284 // as soon as this object is created by Instantiate(). Normal methods can 285 // safely assume this is != nullptr. 286 const FeatureFunctionDescriptor *descriptor_ = nullptr; 287 288 // Feature type for features produced by this feature function. If the 289 // feature function produces features of multiple feature types this is null 290 // and the feature function must return it's feature types in 291 // GetFeatureTypes(). Owned. 292 FeatureType *feature_type_ = nullptr; 293 294 // Prefix used for sub-feature types of this function. 295 std::string prefix_; 296 }; 297 298 // Feature function that can extract features from an object. Templated on 299 // two type arguments: 300 // 301 // OBJ: The "object" from which features are extracted; e.g., a sentence. This 302 // should be a plain type, rather than a reference or pointer. 
303 // 304 // ARGS: A set of 0 or more types that are used to "index" into some part of the 305 // object that should be extracted, e.g. an int token index for a sentence 306 // object. This should not be a reference type. 307 template <class OBJ, class... ARGS> 308 class FeatureFunction 309 : public GenericFeatureFunction, 310 public RegisterableClass<FeatureFunction<OBJ, ARGS...> > { 311 public: 312 using Self = FeatureFunction<OBJ, ARGS...>; 313 314 // Preprocesses the object. This will be called prior to calling Evaluate() 315 // or Compute() on that object. Preprocess(WorkspaceSet * workspaces,const OBJ * object)316 virtual void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const {} 317 318 // Appends features computed from the object and focus to the result. The 319 // default implementation delegates to Compute(), adding a single value if 320 // available. Multi-valued feature functions must override this method. Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)321 virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, 322 ARGS... args, FeatureVector *result) const { 323 FeatureValue value = Compute(workspaces, object, args...); 324 if (value != kNone) result->add(feature_type(), value); 325 } 326 327 // Returns a feature value computed from the object and focus, or kNone if no 328 // value is computed. Single-valued feature functions only need to override 329 // this method. Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args)330 virtual FeatureValue Compute(const WorkspaceSet &workspaces, 331 const OBJ &object, ARGS... args) const { 332 return kNone; 333 } 334 335 // Instantiates a new feature function in a feature extractor from a feature 336 // descriptor. 337 // 338 // Returns a pointer to the newly-created object if everything goes well. 
339 // Returns nullptr if the feature function could not be instantiated (e.g., if 340 // the function with that name is not registered; this usually happens because 341 // the relevant cc_library was not linked-in). Instantiate(const GenericFeatureExtractor * extractor,const FeatureFunctionDescriptor * fd,const std::string & prefix)342 static Self *Instantiate(const GenericFeatureExtractor *extractor, 343 const FeatureFunctionDescriptor *fd, 344 const std::string &prefix) { 345 Self *f = Self::Create(fd->type()); 346 if (f != nullptr) { 347 f->set_extractor(extractor); 348 f->set_descriptor(fd); 349 f->set_prefix(prefix); 350 } 351 return f; 352 } 353 354 private: 355 // Special feature function class for resolving variable references. The type 356 // of the feature function is used for resolving the variable reference. When 357 // evaluated it will either get the feature value(s) from the variable portion 358 // of the feature vector, if present, or otherwise it will call the referenced 359 // feature extractor function directly to extract the feature(s). 360 class Reference; 361 }; 362 363 // Base class for features with nested feature functions. The nested functions 364 // are of type NES, which may be different from the type of the parent function. 365 // NB: NestedFeatureFunction will ensure that all initialization of nested 366 // functions takes place during Setup() and Init() -- after the nested features 367 // are initialized, the parent feature is initialized via SetupNested() and 368 // InitNested(). Alternatively, a derived classes that overrides Setup() and 369 // Init() directly should call Parent::Setup(), Parent::Init(), etc. first. 370 // 371 // Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or 372 // Compute, since the nested functions may be of a different type. 373 template <class NES, class OBJ, class... 
ARGS> 374 class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> { 375 public: 376 using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>; 377 378 // Clean up nested functions. ~NestedFeatureFunction()379 ~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); } 380 381 // By default, just appends the nested feature types. GetFeatureTypes(std::vector<FeatureType * > * types)382 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 383 SAFTM_CHECK(!this->nested().empty()) 384 << "Nested features require nested features to be defined."; 385 for (auto *function : nested_) function->GetFeatureTypes(types); 386 } 387 388 // Sets up the nested features. 389 // 390 // Returns true on success, false otherwise. Setup(TaskContext * context)391 SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override { 392 bool success = CreateNested(this->extractor(), this->descriptor(), &nested_, 393 this->SubPrefix()); 394 if (!success) return false; 395 for (auto *function : nested_) { 396 if (!function->Setup(context)) return false; 397 } 398 if (!SetupNested(context)) return false; 399 return true; 400 } 401 402 // Sets up this NestedFeatureFunction specifically. 403 // 404 // Returns true on success, false otherwise. SetupNested(TaskContext * context)405 SAFTM_MUST_USE_RESULT virtual bool SetupNested(TaskContext *context) { 406 return true; 407 } 408 409 // Initializes the nested features. 410 // 411 // Returns true on success, false otherwise. Init(TaskContext * context)412 SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override { 413 for (auto *function : nested_) { 414 if (!function->Init(context)) return false; 415 } 416 if (!InitNested(context)) return false; 417 return true; 418 } 419 420 // Initializes this NestedFeatureFunction specifically. 421 // 422 // Returns true on success, false otherwise. 
InitNested(TaskContext * context)423 SAFTM_MUST_USE_RESULT virtual bool InitNested(TaskContext *context) { 424 return true; 425 } 426 427 // Gets all the workspaces needed for the nested functions. RequestWorkspaces(WorkspaceRegistry * registry)428 void RequestWorkspaces(WorkspaceRegistry *registry) override { 429 for (auto *function : nested_) function->RequestWorkspaces(registry); 430 } 431 432 // Returns the list of nested feature functions. nested()433 const std::vector<NES *> &nested() const { return nested_; } 434 435 // Instantiates nested feature functions for a feature function. Creates and 436 // initializes one feature function for each sub-descriptor in the feature 437 // descriptor. 438 // 439 // Returns true on success, false otherwise. CreateNested(const GenericFeatureExtractor * extractor,const FeatureFunctionDescriptor * fd,std::vector<NES * > * functions,const std::string & prefix)440 SAFTM_MUST_USE_RESULT static bool CreateNested( 441 const GenericFeatureExtractor *extractor, 442 const FeatureFunctionDescriptor *fd, std::vector<NES *> *functions, 443 const std::string &prefix) { 444 for (int i = 0; i < fd->feature_size(); ++i) { 445 const FeatureFunctionDescriptor &sub = fd->feature(i); 446 NES *f = NES::Instantiate(extractor, &sub, prefix); 447 if (f == nullptr) return false; 448 functions->push_back(f); 449 } 450 return true; 451 } 452 453 protected: 454 // The nested feature functions, if any, in order of declaration in the 455 // feature descriptor. Owned. 456 std::vector<NES *> nested_; 457 }; 458 459 // Base class for a nested feature function that takes nested features with the 460 // same signature as these features, i.e. a meta feature. For this class, we can 461 // provide preprocessing of the nested features. 462 template <class OBJ, class... ARGS> 463 class MetaFeatureFunction 464 : public NestedFeatureFunction<FeatureFunction<OBJ, ARGS...>, OBJ, 465 ARGS...> { 466 public: 467 // Preprocesses using the nested features. 
Preprocess(WorkspaceSet * workspaces,const OBJ * object)468 void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override { 469 for (auto *function : this->nested_) { 470 function->Preprocess(workspaces, object); 471 } 472 } 473 }; 474 475 // Template for a special type of locator: The locator of type 476 // FeatureFunction<OBJ, ARGS...> calls nested functions of type 477 // FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is 478 // responsible for translating by providing the following: 479 // 480 // // Gets the new additional focus. 481 // IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object); 482 // 483 // This is useful to e.g. add a token focus to a parser state based on some 484 // desired property of that state. 485 template <class DER, class OBJ, class IDX, class... ARGS> 486 class FeatureAddFocusLocator 487 : public NestedFeatureFunction<FeatureFunction<OBJ, IDX, ARGS...>, OBJ, 488 ARGS...> { 489 public: Preprocess(WorkspaceSet * workspaces,const OBJ * object)490 void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const override { 491 for (auto *function : this->nested_) { 492 function->Preprocess(workspaces, object); 493 } 494 } 495 Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)496 void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args, 497 FeatureVector *result) const override { 498 IDX focus = 499 static_cast<const DER *>(this)->GetFocus(workspaces, object, args...); 500 for (auto *function : this->nested()) { 501 function->Evaluate(workspaces, object, focus, args..., result); 502 } 503 } 504 505 // Returns the first nested feature's computed value. Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args)506 FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object, 507 ARGS... 
args) const override { 508 IDX focus = 509 static_cast<const DER *>(this)->GetFocus(workspaces, object, args...); 510 return this->nested()[0]->Compute(workspaces, object, focus, args...); 511 } 512 }; 513 514 // CRTP feature locator class. This is a meta feature that modifies ARGS and 515 // then calls the nested feature functions with the modified ARGS. Note that in 516 // order for this template to work correctly, all of ARGS must be types for 517 // which the reference operator & can be interpreted as a pointer to the 518 // argument. The derived class DER must implement the UpdateFocus method which 519 // takes pointers to the ARGS arguments: 520 // 521 // // Updates the current arguments. 522 // void UpdateArgs(const OBJ &object, ARGS *...args) const; 523 template <class DER, class OBJ, class... ARGS> 524 class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> { 525 public: 526 // Feature locators have an additional check that there is no intrinsic type. GetFeatureTypes(std::vector<FeatureType * > * types)527 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 528 SAFTM_CHECK_EQ(this->feature_type(), nullptr) 529 << "FeatureLocators should not have an intrinsic type."; 530 MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types); 531 } 532 533 // Evaluates the locator. Evaluate(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)534 void Evaluate(const WorkspaceSet &workspaces, const OBJ &object, ARGS... args, 535 FeatureVector *result) const override { 536 static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...); 537 for (auto *function : this->nested()) { 538 function->Evaluate(workspaces, object, args..., result); 539 } 540 } 541 542 // Returns the first nested feature's computed value. Compute(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args)543 FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object, 544 ARGS... 
args) const override { 545 static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...); 546 return this->nested()[0]->Compute(workspaces, object, args...); 547 } 548 }; 549 550 // Feature extractor for extracting features from objects of a certain class. 551 // Template type parameters are as defined for FeatureFunction. 552 template <class OBJ, class... ARGS> 553 class FeatureExtractor : public GenericFeatureExtractor { 554 public: 555 // Feature function type for top-level functions in the feature extractor. 556 typedef FeatureFunction<OBJ, ARGS...> Function; 557 typedef FeatureExtractor<OBJ, ARGS...> Self; 558 559 // Feature locator type for the feature extractor. 560 template <class DER> 561 using Locator = FeatureLocator<DER, OBJ, ARGS...>; 562 563 // Initializes feature extractor. FeatureExtractor()564 FeatureExtractor() {} 565 ~FeatureExtractor()566 ~FeatureExtractor() override { utils::STLDeleteElements(&functions_); } 567 568 // Sets up the feature extractor. Note that only top-level functions exist 569 // until Setup() is called. This does not take ownership over the context, 570 // which must outlive this. 571 // 572 // Returns true on success, false otherwise. Setup(TaskContext * context)573 SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) { 574 for (Function *function : functions_) { 575 if (!function->Setup(context)) return false; 576 } 577 return true; 578 } 579 580 // Initializes the feature extractor. Must be called after Setup(). This 581 // does not take ownership over the context, which must outlive this. 582 // 583 // Returns true on success, false otherwise. Init(TaskContext * context)584 SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) { 585 for (Function *function : functions_) { 586 if (!function->Init(context)) return false; 587 } 588 if (!this->InitializeFeatureTypes()) return false; 589 return true; 590 } 591 592 // Requests workspaces from the registry. Must be called after Init(), and 593 // before Preprocess(). 
Does not take ownership over registry. This should be 594 // the same registry used to initialize the WorkspaceSet used in Preprocess() 595 // and ExtractFeatures(). NB: This is a different ordering from that used in 596 // SentenceFeatureRepresentation style feature computation. RequestWorkspaces(WorkspaceRegistry * registry)597 void RequestWorkspaces(WorkspaceRegistry *registry) { 598 for (auto *function : functions_) function->RequestWorkspaces(registry); 599 } 600 601 // Preprocesses the object using feature functions for the phase. Must be 602 // called before any calls to ExtractFeatures() on that object and phase. Preprocess(WorkspaceSet * workspaces,const OBJ * object)603 void Preprocess(WorkspaceSet *workspaces, const OBJ *object) const { 604 for (Function *function : functions_) { 605 function->Preprocess(workspaces, object); 606 } 607 } 608 609 // Extracts features from an object with a focus. This invokes all the 610 // top-level feature functions in the feature extractor. Only feature 611 // functions belonging to the specified phase are invoked. ExtractFeatures(const WorkspaceSet & workspaces,const OBJ & object,ARGS...args,FeatureVector * result)612 void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object, 613 ARGS... args, FeatureVector *result) const { 614 result->reserve(this->feature_types()); 615 616 // Extract features. 617 for (int i = 0; i < functions_.size(); ++i) { 618 functions_[i]->Evaluate(workspaces, object, args..., result); 619 } 620 } 621 622 private: 623 // Creates and initializes all feature functions in the feature extractor. 624 // 625 // Returns true on success, false otherwise. InitializeFeatureFunctions()626 SAFTM_MUST_USE_RESULT bool InitializeFeatureFunctions() override { 627 // Create all top-level feature functions. 
628 for (int i = 0; i < descriptor().feature_size(); ++i) { 629 const FeatureFunctionDescriptor &fd = descriptor().feature(i); 630 Function *function = Function::Instantiate(this, &fd, ""); 631 if (function == nullptr) return false; 632 functions_.push_back(function); 633 } 634 return true; 635 } 636 637 // Collect all feature types used in the feature extractor. GetFeatureTypes(std::vector<FeatureType * > * types)638 void GetFeatureTypes(std::vector<FeatureType *> *types) const override { 639 for (int i = 0; i < functions_.size(); ++i) { 640 functions_[i]->GetFeatureTypes(types); 641 } 642 } 643 644 // Top-level feature functions (and variables) in the feature extractor. 645 // Owned. 646 std::vector<Function *> functions_; 647 }; 648 649 } // namespace mobile 650 } // namespace nlp_saft 651 652 #endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FEL_FEATURE_EXTRACTOR_H_ 653