• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/features/char-ngram-feature.h"
18 
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "lang_id/common/fel/feature-types.h"
24 #include "lang_id/common/fel/task-context.h"
25 #include "lang_id/common/lite_base/logging.h"
26 #include "lang_id/common/math/hash.h"
27 #include "lang_id/common/utf8.h"
28 
29 namespace libtextclassifier3 {
30 namespace mobile {
31 namespace lang_id {
32 
Setup(TaskContext * context)33 bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
34   // Parameters in the feature function descriptor.
35   bool include_terminators = GetBoolParameter("include_terminators", false);
36   if (!include_terminators) {
37     SAFTM_LOG(ERROR) << "No support for include_terminators=true";
38     return false;
39   }
40 
41   bool include_spaces = GetBoolParameter("include_spaces", false);
42   if (include_spaces) {
43     SAFTM_LOG(ERROR) << "No support for include_spaces=true";
44     return false;
45   }
46 
47   bool use_equal_ngram_weight = GetBoolParameter("use_equal_weight", false);
48   if (use_equal_ngram_weight) {
49     SAFTM_LOG(ERROR) << "No support for use_equal_weight=true";
50     return false;
51   }
52 
53   ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
54   ngram_size_ = GetIntParameter("size", 3);
55 
56   counts_.assign(ngram_id_dimension_, 0);
57   return true;
58 }
59 
Init(TaskContext * context)60 bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
61   set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
62   return true;
63 }
64 
ComputeNgramCounts(const LightSentence & sentence) const65 int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
66     const LightSentence &sentence) const {
67   SAFTM_CHECK_EQ(counts_.size(), ngram_id_dimension_);
68   SAFTM_CHECK_EQ(non_zero_count_indices_.size(), 0);
69 
70   int total_count = 0;
71 
72   for (const std::string &word : sentence) {
73     const char *const word_end = word.data() + word.size();
74 
75     // Set ngram_start at the start of the current token (word).
76     const char *ngram_start = word.data();
77 
78     // Set ngram_end ngram_size UTF8 characters after ngram_start.  Note: each
79     // UTF8 character contains between 1 and 4 bytes.
80     const char *ngram_end = ngram_start;
81     int num_utf8_chars = 0;
82     do {
83       ngram_end += utils::OneCharLen(ngram_end);
84       num_utf8_chars++;
85     } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
86 
87     if (num_utf8_chars < ngram_size_) {
88       // Current token is so small, it does not contain a single ngram of
89       // ngram_size UTF8 characters.  Not much we can do in this case ...
90       continue;
91     }
92 
93     // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
94     // UTF8 characters from current token.
95     while (true) {
96       // Compute ngram id: hash(ngram) % ngram_id_dimension
97       int ngram_id = (
98           utils::Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start)
99           % ngram_id_dimension_);
100 
101       // Use a reference to the actual count, such that we can both test whether
102       // the count was 0 and increment it without perfoming two lookups.
103       int &ref_to_count_for_ngram = counts_[ngram_id];
104       if (ref_to_count_for_ngram == 0) {
105         non_zero_count_indices_.push_back(ngram_id);
106       }
107       ref_to_count_for_ngram++;
108       total_count++;
109       if (ngram_end >= word_end) {
110         break;
111       }
112 
113       // Advance both ngram_start and ngram_end by one UTF8 character.  This
114       // way, the number of UTF8 characters between them remains constant
115       // (ngram_size).
116       ngram_start += utils::OneCharLen(ngram_start);
117       ngram_end += utils::OneCharLen(ngram_end);
118     }
119   }  // end of loop over tokens.
120 
121   return total_count;
122 }
123 
Evaluate(const WorkspaceSet & workspaces,const LightSentence & sentence,FeatureVector * result) const124 void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
125                                              const LightSentence &sentence,
126                                              FeatureVector *result) const {
127   // NOTE: we use std::* constructs (instead of absl::Mutex & co) to simplify
128   // porting to Android and to avoid pulling in absl (which increases our code
129   // size).
130   std::lock_guard<std::mutex> mlock(state_mutex_);
131 
132   // Find the char ngram counts.
133   int total_count = ComputeNgramCounts(sentence);
134 
135   // Populate the feature vector.
136   const float norm = static_cast<float>(total_count);
137 
138   // TODO(salcianu): explore treating dense vectors (i.e., many non-zero
139   // elements) separately.
140   for (int ngram_id : non_zero_count_indices_) {
141     const float weight = counts_[ngram_id] / norm;
142     FloatFeatureValue value(ngram_id, weight);
143     result->add(feature_type(), value.discrete_value);
144 
145     // Clear up counts_, for the next invocation of Evaluate().
146     counts_[ngram_id] = 0;
147   }
148 
149   // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
150   non_zero_count_indices_.clear();
151 }
152 
// NOTE(review): presumably registers ContinuousBagOfNgramsFunction with the
// feature function registry, so that it can be instantiated by name --
// confirm against the definition of SAFTM_STATIC_REGISTRATION.
153 SAFTM_STATIC_REGISTRATION(ContinuousBagOfNgramsFunction);
154 
155 }  // namespace lang_id
156 }  // namespace mobile
157 }  // namespace libtextclassifier3
158