• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_TYPES_H_
18 #define LIBTEXTCLASSIFIER_TYPES_H_
19 
20 #include <algorithm>
21 #include <cmath>
22 #include <functional>
23 #include <set>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 #include "util/base/integral_types.h"
28 
29 #include "util/base/logging.h"
30 
31 namespace libtextclassifier2 {
32 
// Sentinel that is never a valid 0-based index.
constexpr int kInvalidIndex = -1;

// Index for a 0-based array of tokens.
using TokenIndex = int;

// Index for a 0-based array of codepoints.
using CodepointIndex = int;

// Half-open range of codepoint indices: `first` is the index of the first
// codepoint of the span and `second` is the index one past its last codepoint.
// TODO(b/71982294): Make it a struct.
using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;

// Returns true when each span starts strictly before the other one ends.
// Spans that merely touch (a.second == b.first) do not overlap.
inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
  if (!(a.first < b.second)) {
    return false;
  }
  return b.first < a.second;
}

// Returns true when both ends of the span are non-negative and the span
// contains at least one codepoint (first < second).
inline bool ValidNonEmptySpan(const CodepointSpan& span) {
  if (span.first < 0 || span.second < 0) {
    return false;
  }
  return span.first < span.second;
}
54 
// Returns true if the span of candidates[considered_candidate] overlaps the
// span of one of its neighbours in chosen_indices_set.
//
// Only two entries of the set are inspected: the first chosen index that is
// not ordered before the candidate (lower_bound), and the chosen index
// directly preceding that one. T must expose a `span` member accepted by
// SpansOverlap.
// NOTE(review): this presumably relies on the set's comparator ordering
// indices by span position and on the chosen spans not overlapping each
// other - confirm against the callers.
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  // Nothing chosen yet, so nothing to conflict with.
  if (chosen_indices_set.empty()) {
    return false;
  }

  const auto& candidate_span = candidates[considered_candidate].span;
  auto neighbour = chosen_indices_set.lower_bound(considered_candidate);

  // Neighbour on the right, if there is one.
  const bool has_right_neighbour = neighbour != chosen_indices_set.end();
  if (has_right_neighbour &&
      SpansOverlap(candidate_span, candidates[*neighbour].span)) {
    return true;
  }

  // Already at the leftmost chosen entry: no neighbour on the left.
  if (neighbour == chosen_indices_set.begin()) {
    return false;
  }

  // Step one entry to the left; the candidate conflicts exactly when it
  // overlaps that entry.
  --neighbour;
  return SpansOverlap(candidate_span, candidates[*neighbour].span);
}
86 
87 // Marks a span in a sequence of tokens. The first element is the index of the
88 // first token in the span, and the second element is the index of the token one
89 // past the end of the span.
90 // TODO(b/71982294): Make it a struct.
91 using TokenSpan = std::pair<TokenIndex, TokenIndex>;
92 
93 // Returns the size of the token span. Assumes that the span is valid.
TokenSpanSize(const TokenSpan & token_span)94 inline int TokenSpanSize(const TokenSpan& token_span) {
95   return token_span.second - token_span.first;
96 }
97 
98 // Returns a token span consisting of one token.
SingleTokenSpan(int token_index)99 inline TokenSpan SingleTokenSpan(int token_index) {
100   return {token_index, token_index + 1};
101 }
102 
103 // Returns an intersection of two token spans. Assumes that both spans are valid
104 // and overlapping.
IntersectTokenSpans(const TokenSpan & token_span1,const TokenSpan & token_span2)105 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
106                                      const TokenSpan& token_span2) {
107   return {std::max(token_span1.first, token_span2.first),
108           std::min(token_span1.second, token_span2.second)};
109 }
110 
111 // Returns and expanded token span by adding a certain number of tokens on its
112 // left and on its right.
ExpandTokenSpan(const TokenSpan & token_span,int num_tokens_left,int num_tokens_right)113 inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
114                                  int num_tokens_left, int num_tokens_right) {
115   return {token_span.first - num_tokens_left,
116           token_span.second + num_tokens_right};
117 }
118 
119 // Token holds a token, its position in the original string and whether it was
120 // part of the input span.
121 struct Token {
122   std::string value;
123   CodepointIndex start;
124   CodepointIndex end;
125 
126   // Whether the token is a padding token.
127   bool is_padding;
128 
129   // Default constructor constructs the padding-token.
TokenToken130   Token()
131       : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
132 
TokenToken133   Token(const std::string& arg_value, CodepointIndex arg_start,
134         CodepointIndex arg_end)
135       : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
136 
137   bool operator==(const Token& other) const {
138     return value == other.value && start == other.start && end == other.end &&
139            is_padding == other.is_padding;
140   }
141 
IsContainedInSpanToken142   bool IsContainedInSpan(CodepointSpan span) const {
143     return start >= span.first && end <= span.second;
144   }
145 };
146 
147 // Pretty-printing function for Token.
148 inline logging::LoggingStringStream& operator<<(
149     logging::LoggingStringStream& stream, const Token& token) {
150   if (!token.is_padding) {
151     return stream << "Token(\"" << token.value << "\", " << token.start << ", "
152                   << token.end << ")";
153   } else {
154     return stream << "Token()";
155   }
156 }
157 
// Granularity of a parsed date time, from year down to second.
enum DatetimeGranularity {
  // GRANULARITY_UNKNOWN is used as a proxy for the structure holding it being
  // uninitialized.
  GRANULARITY_UNKNOWN = -1,

  GRANULARITY_YEAR = 0,
  GRANULARITY_MONTH = 1,
  GRANULARITY_WEEK = 2,
  GRANULARITY_DAY = 3,
  GRANULARITY_HOUR = 4,
  GRANULARITY_MINUTE = 5,
  GRANULARITY_SECOND = 6
};
169 
170 struct DatetimeParseResult {
171   // The absolute time in milliseconds since the epoch in UTC. This is derived
172   // from the reference time and the fields specified in the text - so it may
173   // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm)
174   int64 time_ms_utc;
175 
176   // The precision of the estimate then in to calculating the milliseconds
177   DatetimeGranularity granularity;
178 
DatetimeParseResultDatetimeParseResult179   DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
180 
DatetimeParseResultDatetimeParseResult181   DatetimeParseResult(int64 arg_time_ms_utc,
182                       DatetimeGranularity arg_granularity)
183       : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
184 
IsSetDatetimeParseResult185   bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
186 
187   bool operator==(const DatetimeParseResult& other) const {
188     return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
189   }
190 };
191 
192 const float kFloatCompareEpsilon = 1e-5;
193 
194 struct DatetimeParseResultSpan {
195   CodepointSpan span;
196   DatetimeParseResult data;
197   float target_classification_score;
198   float priority_score;
199 
200   bool operator==(const DatetimeParseResultSpan& other) const {
201     return span == other.span && data.granularity == other.data.granularity &&
202            data.time_ms_utc == other.data.time_ms_utc &&
203            std::abs(target_classification_score -
204                     other.target_classification_score) < kFloatCompareEpsilon &&
205            std::abs(priority_score - other.priority_score) <
206                kFloatCompareEpsilon;
207   }
208 };
209 
210 // Pretty-printing function for DatetimeParseResultSpan.
211 inline logging::LoggingStringStream& operator<<(
212     logging::LoggingStringStream& stream,
213     const DatetimeParseResultSpan& value) {
214   return stream << "DatetimeParseResultSpan({" << value.span.first << ", "
215                 << value.span.second << "}, {/*time_ms_utc=*/ "
216                 << value.data.time_ms_utc << ", /*granularity=*/ "
217                 << value.data.granularity << "})";
218 }
219 
220 struct ClassificationResult {
221   std::string collection;
222   float score;
223   DatetimeParseResult datetime_parse_result;
224 
225   // Internal score used for conflict resolution.
226   float priority_score;
227 
ClassificationResultClassificationResult228   explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
229 
ClassificationResultClassificationResult230   ClassificationResult(const std::string& arg_collection, float arg_score)
231       : collection(arg_collection),
232         score(arg_score),
233         priority_score(arg_score) {}
234 
ClassificationResultClassificationResult235   ClassificationResult(const std::string& arg_collection, float arg_score,
236                        float arg_priority_score)
237       : collection(arg_collection),
238         score(arg_score),
239         priority_score(arg_priority_score) {}
240 };
241 
242 // Pretty-printing function for ClassificationResult.
243 inline logging::LoggingStringStream& operator<<(
244     logging::LoggingStringStream& stream, const ClassificationResult& result) {
245   return stream << "ClassificationResult(" << result.collection << ", "
246                 << result.score << ")";
247 }
248 
249 // Pretty-printing function for std::vector<ClassificationResult>.
250 inline logging::LoggingStringStream& operator<<(
251     logging::LoggingStringStream& stream,
252     const std::vector<ClassificationResult>& results) {
253   stream = stream << "{\n";
254   for (const ClassificationResult& result : results) {
255     stream = stream << "    " << result << "\n";
256   }
257   stream = stream << "}";
258   return stream;
259 }
260 
261 // Represents a result of Annotate call.
262 struct AnnotatedSpan {
263   // Unicode codepoint indices in the input string.
264   CodepointSpan span = {kInvalidIndex, kInvalidIndex};
265 
266   // Classification result for the span.
267   std::vector<ClassificationResult> classification;
268 };
269 
270 // Pretty-printing function for AnnotatedSpan.
271 inline logging::LoggingStringStream& operator<<(
272     logging::LoggingStringStream& stream, const AnnotatedSpan& span) {
273   std::string best_class;
274   float best_score = -1;
275   if (!span.classification.empty()) {
276     best_class = span.classification[0].collection;
277     best_score = span.classification[0].score;
278   }
279   return stream << "Span(" << span.span.first << ", " << span.span.second
280                 << ", " << best_class << ", " << best_score << ")";
281 }
282 
// StringPiece analogue for std::vector<T>: a non-owning view of the range
// between two const_iterators of a std::vector<T>. The view stores only the
// iterators, so it must not be used after they have been invalidated.
template <class T>
class VectorSpan {
 public:
  // Constructs an empty span.
  VectorSpan() : begin_(), end_() {}

  // Views the whole vector. Intentionally implicit so that a vector can be
  // passed wherever a VectorSpan is expected.
  VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
      : begin_(v.begin()), end_(v.end()) {}

  // Views the range [begin, end).
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  // Returns the i-th element of the span. No bounds checking is performed.
  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  // Number of elements in the span.
  int size() const { return static_cast<int>(end_ - begin_); }

  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Returns a pointer to the first element, or nullptr for an empty span.
  // The pointer type follows T, so this works for every element type, and
  // the end iterator is never dereferenced.
  const T* data() const { return begin_ == end_ ? nullptr : &(*begin_); }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
307 
// Fields of a date time expression. field_set_mask records which of the
// fields have been set.
struct DateParseData {
  // Relation of the date time to now.
  enum Relation {
    NEXT = 1,
    NEXT_OR_SAME = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  // What a relation applies to: a day of the week or a calendar unit.
  enum RelationType {
    MONDAY = 1,
    TUESDAY = 2,
    WEDNESDAY = 3,
    THURSDAY = 4,
    FRIDAY = 5,
    SATURDAY = 6,
    SUNDAY = 7,
    DAY = 8,
    WEEK = 9,
    MONTH = 10,
    YEAR = 11
  };

  // Bit flags for field_set_mask.
  enum Fields {
    YEAR_FIELD = 1 << 0,
    MONTH_FIELD = 1 << 1,
    DAY_FIELD = 1 << 2,
    HOUR_FIELD = 1 << 3,
    MINUTE_FIELD = 1 << 4,
    SECOND_FIELD = 1 << 5,
    AMPM_FIELD = 1 << 6,
    ZONE_OFFSET_FIELD = 1 << 7,
    DST_OFFSET_FIELD = 1 << 8,
    RELATION_FIELD = 1 << 9,
    RELATION_TYPE_FIELD = 1 << 10,
    RELATION_DISTANCE_FIELD = 1 << 11
  };

  // Values of the ampm field.
  enum AMPM { AM = 0, PM = 1 };

  // Units of time.
  enum TimeUnit {
    DAYS = 1,
    WEEKS = 2,
    MONTHS = 3,
    HOURS = 4,
    MINUTES = 5,
    SECONDS = 6,
    YEARS = 7
  };

  // Bit mask of fields which have been set on the struct.
  int field_set_mask;

  // Fields describing absolute date fields.
  // Year of the date seen in the text match.
  int year;
  // Month of the year starting with January = 1.
  int month;
  // Day of the month starting with 1.
  int day_of_month;
  // Hour of the day with a range of 0-23,
  // values less than 12 need the AMPM field below or heuristics
  // to definitively determine the time.
  int hour;
  // Minute of the hour with a range of 0-59.
  int minute;
  // Second of the minute with a range of 0-59.
  int second;
  // 0 == AM, 1 == PM
  int ampm;
  // Number of hours offset from UTC this date time is in.
  int zone_offset;
  // Number of hours offset for DST.
  int dst_offset;

  // The permutation from now that was made to find the date time.
  Relation relation;
  // The unit of measure of the change to the date time.
  RelationType relation_type;
  // The number of units of change that were made.
  int relation_distance;
};
393 
394 }  // namespace libtextclassifier2
395 
396 #endif  // LIBTEXTCLASSIFIER_TYPES_H_
397