1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lang_id/language-identifier-features.h"
18
19 #include <utility>
20 #include <vector>
21
22 #include "common/feature-extractor.h"
23 #include "common/feature-types.h"
24 #include "common/task-context.h"
25 #include "util/hash/hash.h"
26 #include "util/strings/utf8.h"
27
28 namespace libtextclassifier {
29 namespace nlp_core {
30 namespace lang_id {
31
Setup(TaskContext * context)32 bool ContinuousBagOfNgramsFunction::Setup(TaskContext *context) {
33 // Parameters in the feature function descriptor.
34 ngram_id_dimension_ = GetIntParameter("id_dim", 10000);
35 ngram_size_ = GetIntParameter("size", 3);
36
37 counts_.assign(ngram_id_dimension_, 0);
38 return true;
39 }
40
Init(TaskContext * context)41 bool ContinuousBagOfNgramsFunction::Init(TaskContext *context) {
42 set_feature_type(new NumericFeatureType(name(), ngram_id_dimension_));
43 return true;
44 }
45
ComputeNgramCounts(const LightSentence & sentence) const46 int ContinuousBagOfNgramsFunction::ComputeNgramCounts(
47 const LightSentence &sentence) const {
48 // Invariant 1: counts_.size() == ngram_id_dimension_. Holds at the end of
49 // the constructor. After that, no method changes counts_.size().
50 TC_DCHECK_EQ(counts_.size(), ngram_id_dimension_);
51
52 // Invariant 2: the vector non_zero_count_indices_ is empty. The vector
53 // non_zero_count_indices_ is empty at construction time and gets emptied at
54 // the end of each call to Evaluate(). Hence, this invariant holds at the
55 // beginning of each run of Evaluate(), where the only call to this code takes
56 // place.
57 TC_DCHECK(non_zero_count_indices_.empty());
58
59 int total_count = 0;
60
61 for (int i = 0; i < sentence.num_words(); ++i) {
62 const std::string &word = sentence.word(i);
63 const char *const word_end = word.data() + word.size();
64
65 // Set ngram_start at the start of the current token (word).
66 const char *ngram_start = word.data();
67
68 // Set ngram_end ngram_size UTF8 characters after ngram_start. Note: each
69 // UTF8 character contains between 1 and 4 bytes.
70 const char *ngram_end = ngram_start;
71 int num_utf8_chars = 0;
72 do {
73 ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end);
74 num_utf8_chars++;
75 } while ((num_utf8_chars < ngram_size_) && (ngram_end < word_end));
76
77 if (num_utf8_chars < ngram_size_) {
78 // Current token is so small, it does not contain a single ngram of
79 // ngram_size UTF8 characters. Not much we can do in this case ...
80 continue;
81 }
82
83 // At this point, [ngram_start, ngram_end) is the first ngram of ngram_size
84 // UTF8 characters from current token.
85 while (true) {
86 // Compute ngram_id: hash(ngram) % ngram_id_dimension
87 int ngram_id =
88 (Hash32WithDefaultSeed(ngram_start, ngram_end - ngram_start) %
89 ngram_id_dimension_);
90
91 // Use a reference to the actual count, such that we can both test whether
92 // the count was 0 and increment it without perfoming two lookups.
93 //
94 // Due to the way we compute ngram_id, 0 <= ngram_id < ngram_id_dimension.
95 // Hence, by Invariant 1 (above), the access counts_[ngram_id] is safe.
96 int &ref_to_count_for_ngram = counts_[ngram_id];
97 if (ref_to_count_for_ngram == 0) {
98 non_zero_count_indices_.push_back(ngram_id);
99 }
100 ref_to_count_for_ngram++;
101 total_count++;
102 if (ngram_end >= word_end) {
103 break;
104 }
105
106 // Advance both ngram_start and ngram_end by one UTF8 character. This
107 // way, the number of UTF8 characters between them remains constant
108 // (ngram_size).
109 ngram_start += GetNumBytesForNonZeroUTF8Char(ngram_start);
110 ngram_end += GetNumBytesForNonZeroUTF8Char(ngram_end);
111 }
112 } // end of loop over tokens.
113
114 return total_count;
115 }
116
Evaluate(const WorkspaceSet & workspaces,const LightSentence & sentence,FeatureVector * result) const117 void ContinuousBagOfNgramsFunction::Evaluate(const WorkspaceSet &workspaces,
118 const LightSentence &sentence,
119 FeatureVector *result) const {
120 // Find the char ngram counts.
121 int total_count = ComputeNgramCounts(sentence);
122
123 // Populate the feature vector.
124 const float norm = static_cast<float>(total_count);
125
126 for (int ngram_id : non_zero_count_indices_) {
127 const float weight = counts_[ngram_id] / norm;
128 FloatFeatureValue value(ngram_id, weight);
129 result->add(feature_type(), value.discrete_value);
130
131 // Clear up counts_, for the next invocation of Evaluate().
132 counts_[ngram_id] = 0;
133 }
134
135 // Clear up non_zero_count_indices_, for the next invocation of Evaluate().
136 non_zero_count_indices_.clear();
137 }
138
139 } // namespace lang_id
140 } // namespace nlp_core
141 } // namespace libtextclassifier
142