• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
18 #define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
19 
20 #include <time.h>
21 #include <algorithm>
22 #include <cmath>
23 #include <functional>
24 #include <map>
25 #include <set>
26 #include <string>
27 #include <utility>
28 #include <vector>
29 
30 #include "annotator/entity-data_generated.h"
31 #include "utils/base/integral_types.h"
32 #include "utils/base/logging.h"
33 #include "utils/flatbuffers.h"
34 #include "utils/variant.h"
35 
36 namespace libtextclassifier3 {
37 
// Sentinel index meaning "no valid position" (used, for example, for the
// boundaries of an unset span and for the position of a padding token).
constexpr int kInvalidIndex{-1};

// Position inside a 0-based sequence of tokens.
using TokenIndex = int;

// Position inside a 0-based sequence of codepoints.
using CodepointIndex = int;

// Half-open range over a codepoint sequence: `first` is the index of the span's
// first codepoint and `second` is the index one past its last codepoint.
// TODO(b/71982294): Make it a struct.
using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;

// Returns true when each span starts before the other one ends. Spans that only
// touch (a.second == b.first) are not considered overlapping.
inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
  const bool a_starts_before_b_ends = a.first < b.second;
  const bool b_starts_before_a_ends = b.first < a.second;
  return a_starts_before_b_ends && b_starts_before_a_ends;
}

// Returns true when both ends of `span` are non-negative and the span covers
// at least one codepoint (first < second).
inline bool ValidNonEmptySpan(const CodepointSpan& span) {
  if (span.first < 0 || span.second < 0) {
    return false;
  }
  return span.first < span.second;
}

// Returns whether candidates[considered_candidate] overlaps an already chosen
// candidate from `chosen_indices_set`. Nothing is inserted into the set.
//
// Only the two chosen entries adjacent to the candidate's position in the set
// ordering are examined: the first entry not ordered before it (lower_bound)
// and the entry immediately preceding that one. This is only sufficient if the
// set's comparator orders indices by span position and the chosen spans do not
// overlap each other. That is presumably guaranteed by the callers; TODO
// confirm. `T` must have a `span` member accepted by SpansOverlap.
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  if (chosen_indices_set.empty()) {
    return false;
  }
  const auto& candidate_span = candidates[considered_candidate].span;

  // Neighbour on the right: the first chosen index not ordered before the
  // candidate.
  auto right = chosen_indices_set.lower_bound(considered_candidate);
  if (right != chosen_indices_set.end() &&
      SpansOverlap(candidate_span, candidates[*right].span)) {
    return true;
  }

  // Nothing is ordered before the candidate, so there is no left neighbour.
  if (right == chosen_indices_set.begin()) {
    return false;
  }

  // Neighbour on the left: the chosen index just before `right`.
  auto left = right;
  --left;
  return SpansOverlap(candidate_span, candidates[*left].span);
}
91 
92 // Marks a span in a sequence of tokens. The first element is the index of the
93 // first token in the span, and the second element is the index of the token one
94 // past the end of the span.
95 // TODO(b/71982294): Make it a struct.
96 using TokenSpan = std::pair<TokenIndex, TokenIndex>;
97 
98 // Returns the size of the token span. Assumes that the span is valid.
TokenSpanSize(const TokenSpan & token_span)99 inline int TokenSpanSize(const TokenSpan& token_span) {
100   return token_span.second - token_span.first;
101 }
102 
103 // Returns a token span consisting of one token.
SingleTokenSpan(int token_index)104 inline TokenSpan SingleTokenSpan(int token_index) {
105   return {token_index, token_index + 1};
106 }
107 
108 // Returns an intersection of two token spans. Assumes that both spans are valid
109 // and overlapping.
IntersectTokenSpans(const TokenSpan & token_span1,const TokenSpan & token_span2)110 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
111                                      const TokenSpan& token_span2) {
112   return {std::max(token_span1.first, token_span2.first),
113           std::min(token_span1.second, token_span2.second)};
114 }
115 
116 // Returns and expanded token span by adding a certain number of tokens on its
117 // left and on its right.
ExpandTokenSpan(const TokenSpan & token_span,int num_tokens_left,int num_tokens_right)118 inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
119                                  int num_tokens_left, int num_tokens_right) {
120   return {token_span.first - num_tokens_left,
121           token_span.second + num_tokens_right};
122 }
123 
124 // Token holds a token, its position in the original string and whether it was
125 // part of the input span.
126 struct Token {
127   std::string value;
128   CodepointIndex start;
129   CodepointIndex end;
130 
131   // Whether the token is a padding token.
132   bool is_padding;
133 
134   // Default constructor constructs the padding-token.
TokenToken135   Token()
136       : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
137 
TokenToken138   Token(const std::string& arg_value, CodepointIndex arg_start,
139         CodepointIndex arg_end)
140       : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
141 
142   bool operator==(const Token& other) const {
143     return value == other.value && start == other.start && end == other.end &&
144            is_padding == other.is_padding;
145   }
146 
IsContainedInSpanToken147   bool IsContainedInSpan(CodepointSpan span) const {
148     return start >= span.first && end <= span.second;
149   }
150 };
151 
152 // Pretty-printing function for Token.
153 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
154                                          const Token& token);
155 
// Precision of a parsed datetime, ordered from the coarsest unit (year) to the
// finest (second).
enum DatetimeGranularity {
  // Used as a proxy for the enclosing structure being uninitialized.
  GRANULARITY_UNKNOWN = -1,
  GRANULARITY_YEAR = 0,
  GRANULARITY_MONTH = 1,
  GRANULARITY_WEEK = 2,
  GRANULARITY_DAY = 3,
  GRANULARITY_HOUR = 4,
  GRANULARITY_MINUTE = 5,
  GRANULARITY_SECOND = 6,
};
167 
168 struct DatetimeParseResult {
169   // The absolute time in milliseconds since the epoch in UTC.
170   int64 time_ms_utc;
171 
172   // The precision of the estimate then in to calculating the milliseconds
173   DatetimeGranularity granularity;
174 
DatetimeParseResultDatetimeParseResult175   DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
176 
DatetimeParseResultDatetimeParseResult177   DatetimeParseResult(int64 arg_time_ms_utc,
178                       DatetimeGranularity arg_granularity)
179       : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
180 
IsSetDatetimeParseResult181   bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
182 
183   bool operator==(const DatetimeParseResult& other) const {
184     return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
185   }
186 };
187 
188 const float kFloatCompareEpsilon = 1e-5;
189 
190 struct DatetimeParseResultSpan {
191   CodepointSpan span;
192   std::vector<DatetimeParseResult> data;
193   float target_classification_score;
194   float priority_score;
195 
196   bool operator==(const DatetimeParseResultSpan& other) const {
197     return span == other.span && data == other.data &&
198            std::abs(target_classification_score -
199                     other.target_classification_score) < kFloatCompareEpsilon &&
200            std::abs(priority_score - other.priority_score) <
201                kFloatCompareEpsilon;
202   }
203 };
204 
205 // Pretty-printing function for DatetimeParseResultSpan.
206 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
207                                          const DatetimeParseResultSpan& value);
208 
209 struct ClassificationResult {
210   std::string collection;
211   float score;
212   DatetimeParseResult datetime_parse_result;
213   std::string serialized_knowledge_result;
214   std::string contact_name, contact_given_name, contact_nickname,
215       contact_email_address, contact_phone_number, contact_id;
216   std::string app_name, app_package_name;
217   int64 numeric_value;
218 
219   // Length of the parsed duration in milliseconds.
220   int64 duration_ms;
221 
222   // Internal score used for conflict resolution.
223   float priority_score;
224 
225 
226   // Entity data information.
227   std::string serialized_entity_data;
entity_dataClassificationResult228   const EntityData* entity_data() {
229     return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
230                                                serialized_entity_data.size());
231   }
232 
ClassificationResultClassificationResult233   explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
234 
ClassificationResultClassificationResult235   ClassificationResult(const std::string& arg_collection, float arg_score)
236       : collection(arg_collection),
237         score(arg_score),
238         priority_score(arg_score) {}
239 
ClassificationResultClassificationResult240   ClassificationResult(const std::string& arg_collection, float arg_score,
241                        float arg_priority_score)
242       : collection(arg_collection),
243         score(arg_score),
244         priority_score(arg_priority_score) {}
245 };
246 
247 // Pretty-printing function for ClassificationResult.
248 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
249                                          const ClassificationResult& result);
250 
251 // Pretty-printing function for std::vector<ClassificationResult>.
252 logging::LoggingStringStream& operator<<(
253     logging::LoggingStringStream& stream,
254     const std::vector<ClassificationResult>& results);
255 
256 // Represents a result of Annotate call.
257 struct AnnotatedSpan {
258   enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME };
259 
260   // Unicode codepoint indices in the input string.
261   CodepointSpan span = {kInvalidIndex, kInvalidIndex};
262 
263   // Classification result for the span.
264   std::vector<ClassificationResult> classification;
265 
266   // The source of the annotation, used in conflict resolution.
267   Source source = Source::OTHER;
268 
269   AnnotatedSpan() = default;
270 
AnnotatedSpanAnnotatedSpan271   AnnotatedSpan(CodepointSpan arg_span,
272                 std::vector<ClassificationResult> arg_classification)
273       : span(arg_span), classification(std::move(arg_classification)) {}
274 };
275 
276 // Pretty-printing function for AnnotatedSpan.
277 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
278                                          const AnnotatedSpan& span);
279 
// StringPiece analogue for std::vector<T>: a non-owning, read-only view over a
// contiguous range of a vector. Only iterators are stored, so the underlying
// vector must outlive the view and must not reallocate (for example through
// push_back) while the view is in use.
template <class T>
class VectorSpan {
 public:
  VectorSpan() : begin_(), end_() {}
  VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
      : begin_(v.begin()), end_(v.end()) {}
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  // Unchecked element access; `i` must be less than size().
  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  int size() const { return static_cast<int>(end_ - begin_); }
  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Pointer to the first element, or nullptr for an empty view. The element
  // type is T; it was previously hard-coded to float, which made data()
  // unusable for other T. An empty view is never dereferenced, because
  // dereferencing an end (or value-initialized) iterator is undefined behavior.
  const T* data() const { return begin_ == end_ ? nullptr : &*begin_; }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
304 
// Structured result of parsing a date/time expression. `field_set_mask`
// records which of the fields below have actually been set (see `Fields`).
struct DateParseData {
  enum class Relation {
    UNSPECIFIED = 0,
    NEXT = 1,
    NEXT_OR_SAME = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  enum class RelationType {
    UNSPECIFIED = 0,
    SUNDAY = 1,
    MONDAY = 2,
    TUESDAY = 3,
    WEDNESDAY = 4,
    THURSDAY = 5,
    FRIDAY = 6,
    SATURDAY = 7,
    DAY = 8,
    WEEK = 9,
    MONTH = 10,
    YEAR = 11,
    HOUR = 12,
    MINUTE = 13,
    SECOND = 14,
  };

  // Bit flags for `field_set_mask`, one per field of this struct. This is a
  // plain (unscoped) enum so the flags combine directly with `|` into an int.
  enum Fields {
    YEAR_FIELD = 1 << 0,
    MONTH_FIELD = 1 << 1,
    DAY_FIELD = 1 << 2,
    HOUR_FIELD = 1 << 3,
    MINUTE_FIELD = 1 << 4,
    SECOND_FIELD = 1 << 5,
    AMPM_FIELD = 1 << 6,
    ZONE_OFFSET_FIELD = 1 << 7,
    DST_OFFSET_FIELD = 1 << 8,
    RELATION_FIELD = 1 << 9,
    RELATION_TYPE_FIELD = 1 << 10,
    RELATION_DISTANCE_FIELD = 1 << 11
  };

  enum class AMPM { AM = 0, PM = 1 };

  enum class TimeUnit {
    DAYS = 1,
    WEEKS = 2,
    MONTHS = 3,
    HOURS = 4,
    MINUTES = 5,
    SECONDS = 6,
    YEARS = 7
  };

  // Bit mask of fields which have been set on the struct.
  int field_set_mask = 0;

  // Fields describing absolute date fields.
  // Year of the date seen in the text match.
  int year = 0;
  // Month of the year starting with January = 1.
  int month = 0;
  // Day of the month starting with 1.
  int day_of_month = 0;
  // Hour of the day with a range of 0-23. Values less than 12 need the AMPM
  // field below or heuristics to definitively determine the time.
  int hour = 0;
  // Minute of the hour with a range of 0-59.
  int minute = 0;
  // Second of the minute with a range of 0-59.
  int second = 0;
  // 0 == AM, 1 == PM
  AMPM ampm = AMPM::AM;
  // Number of hours offset from UTC this date time is in.
  int zone_offset = 0;
  // Number of hours offset for DST.
  int dst_offset = 0;

  // The relation to "now" that was used to find the date time.
  Relation relation = Relation::UNSPECIFIED;
  // The unit of measure of the change to the date time.
  RelationType relation_type = RelationType::UNSPECIFIED;
  // The number of units of change that were made.
  int relation_distance = 0;

  DateParseData() = default;

  // Sets every field explicitly. The parameter names intentionally mirror the
  // members: in each mem-initializer `x(x)`, the outer name is the member and
  // the inner one is the parameter.
  DateParseData(int field_set_mask, int year, int month, int day_of_month,
                int hour, int minute, int second, AMPM ampm, int zone_offset,
                int dst_offset, Relation relation, RelationType relation_type,
                int relation_distance)
      : field_set_mask(field_set_mask),
        year(year),
        month(month),
        day_of_month(day_of_month),
        hour(hour),
        minute(minute),
        second(second),
        ampm(ampm),
        zone_offset(zone_offset),
        dst_offset(dst_offset),
        relation(relation),
        relation_type(relation_type),
        relation_distance(relation_distance) {}
};
416 
417 // Pretty-printing function for DateParseData.
418 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
419                                          const DateParseData& data);
420 
421 }  // namespace libtextclassifier3
422 
423 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
424