• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
18 #define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
19 
20 #include <time.h>
21 
22 #include <algorithm>
23 #include <cmath>
24 #include <functional>
25 #include <map>
26 #include <set>
27 #include <string>
28 #include <unordered_set>
29 #include <utility>
30 #include <vector>
31 
32 #include "annotator/entity-data_generated.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "utils/base/integral_types.h"
35 #include "utils/base/logging.h"
36 #include "utils/flatbuffers/flatbuffers.h"
37 #include "utils/optional.h"
38 #include "utils/variant.h"
39 
40 namespace libtextclassifier3 {
41 
// Sentinel meaning "no index"; used for both ends of an invalid span.
constexpr int kInvalidIndex = -1;

// Day-of-week numbering: the week starts on Sunday, which is 1.
constexpr int kSunday = 1;
constexpr int kMonday = kSunday + 1;
constexpr int kTuesday = kMonday + 1;
constexpr int kWednesday = kTuesday + 1;
constexpr int kThursday = kWednesday + 1;
constexpr int kFriday = kThursday + 1;
constexpr int kSaturday = kFriday + 1;

// Position in a 0-based sequence of tokens.
using TokenIndex = int;

// Position in a 0-based sequence of Unicode codepoints.
using CodepointIndex = int;
56 
57 // Marks a span in a sequence of codepoints. The first element is the index of
58 // the first codepoint of the span, and the second element is the index of the
59 // codepoint one past the end of the span.
60 struct CodepointSpan {
61   static const CodepointSpan kInvalid;
62 
CodepointSpanCodepointSpan63   CodepointSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
64 
CodepointSpanCodepointSpan65   CodepointSpan(CodepointIndex start, CodepointIndex end)
66       : first(start), second(end) {}
67 
68   CodepointSpan(const CodepointSpan& other) = default;
69   CodepointSpan& operator=(const CodepointSpan& other) = default;
70 
71   bool operator==(const CodepointSpan& other) const {
72     return this->first == other.first && this->second == other.second;
73   }
74 
75   bool operator!=(const CodepointSpan& other) const {
76     return !(*this == other);
77   }
78 
79   bool operator<(const CodepointSpan& other) const {
80     if (this->first != other.first) {
81       return this->first < other.first;
82     }
83     return this->second < other.second;
84   }
85 
IsValidCodepointSpan86   bool IsValid() const {
87     return this->first != kInvalidIndex && this->second != kInvalidIndex &&
88            this->first <= this->second && this->first >= 0;
89   }
90 
IsEmptyCodepointSpan91   bool IsEmpty() const { return this->first == this->second; }
92 
93   CodepointIndex first;
94   CodepointIndex second;
95 };
96 
97 // Pretty-printing function for CodepointSpan.
98 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
99                                          const CodepointSpan& span);
100 
SpansOverlap(const CodepointSpan & a,const CodepointSpan & b)101 inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
102   return a.first < b.second && b.first < a.second;
103 }
104 
SpanContains(const CodepointSpan & span,const CodepointSpan & sub_span)105 inline bool SpanContains(const CodepointSpan& span,
106                          const CodepointSpan& sub_span) {
107   return span.first <= sub_span.first && span.second >= sub_span.second;
108 }
109 
// Returns true when the span of candidates[considered_candidate] overlaps the
// span of one of its neighbours in `chosen_indices_set`.
//
// `chosen_indices_set` holds indices into `candidates`, ordered by its own
// comparator. Only two chosen entries are checked: the first one not ordered
// before `considered_candidate`, and the one immediately before it. This
// presumably relies on the comparator ordering indices by span position and
// on the chosen spans not overlapping each other — TODO confirm at call sites.
// The function only queries the set; it never modifies it.
//
// T must expose a `span` member accepted by SpansOverlap().
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  // Nothing chosen yet, so nothing can conflict. (Also avoids touching
  // `candidates` at all in this case.)
  if (chosen_indices_set.empty()) {
    return false;
  }

  const auto overlaps_chosen = [&](int chosen_index) {
    return SpansOverlap(candidates[considered_candidate].span,
                        candidates[chosen_index].span);
  };

  // Nearest chosen neighbour at or after the candidate in set order.
  auto neighbor = chosen_indices_set.lower_bound(considered_candidate);
  if (neighbor != chosen_indices_set.end() && overlaps_chosen(*neighbor)) {
    return true;
  }

  // No chosen neighbour before the candidate in set order.
  if (neighbor == chosen_indices_set.begin()) {
    return false;
  }

  // Nearest chosen neighbour before the candidate in set order.
  --neighbor;
  return overlaps_chosen(*neighbor);
}
141 
142 // Marks a span in a sequence of tokens. The first element is the index of the
143 // first token in the span, and the second element is the index of the token one
144 // past the end of the span.
145 struct TokenSpan {
146   static const TokenSpan kInvalid;
147 
TokenSpanTokenSpan148   TokenSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
149 
TokenSpanTokenSpan150   TokenSpan(TokenIndex start, TokenIndex end) : first(start), second(end) {}
151 
152   // Creates a token span consisting of one token.
TokenSpanTokenSpan153   explicit TokenSpan(int token_index)
154       : first(token_index), second(token_index + 1) {}
155 
156   TokenSpan& operator=(const TokenSpan& other) = default;
157 
158   bool operator==(const TokenSpan& other) const {
159     return this->first == other.first && this->second == other.second;
160   }
161 
162   bool operator!=(const TokenSpan& other) const { return !(*this == other); }
163 
164   bool operator<(const TokenSpan& other) const {
165     if (this->first != other.first) {
166       return this->first < other.first;
167     }
168     return this->second < other.second;
169   }
170 
IsValidTokenSpan171   bool IsValid() const {
172     return this->first != kInvalidIndex && this->second != kInvalidIndex;
173   }
174 
175   // Returns the size of the token span. Assumes that the span is valid.
SizeTokenSpan176   int Size() const { return this->second - this->first; }
177 
178   // Returns an expanded token span by adding a certain number of tokens on its
179   // left and on its right.
ExpandTokenSpan180   TokenSpan Expand(int num_tokens_left, int num_tokens_right) const {
181     return {this->first - num_tokens_left, this->second + num_tokens_right};
182   }
183 
184   TokenIndex first;
185   TokenIndex second;
186 };
187 
188 // Pretty-printing function for TokenSpan.
189 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
190                                          const TokenSpan& span);
191 
192 // Returns an intersection of two token spans. Assumes that both spans are
193 // valid and overlapping.
IntersectTokenSpans(const TokenSpan & token_span1,const TokenSpan & token_span2)194 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
195                                      const TokenSpan& token_span2) {
196   return {std::max(token_span1.first, token_span2.first),
197           std::min(token_span1.second, token_span2.second)};
198 }
199 
200 // Token holds a token, its position in the original string and whether it was
201 // part of the input span.
202 struct Token {
203   std::string value;
204   CodepointIndex start;
205   CodepointIndex end;
206 
207   // Whether the token is a padding token.
208   bool is_padding;
209 
210   // Whether the token contains only white characters.
211   bool is_whitespace;
212 
213   // Default constructor constructs the padding-token.
TokenToken214   Token()
215       : Token(/*arg_value=*/"", /*arg_start=*/kInvalidIndex,
216               /*arg_end=*/kInvalidIndex, /*is_padding=*/true,
217               /*is_whitespace=*/false) {}
218 
TokenToken219   Token(const std::string& arg_value, CodepointIndex arg_start,
220         CodepointIndex arg_end)
221       : Token(/*arg_value=*/arg_value, /*arg_start=*/arg_start,
222               /*arg_end=*/arg_end, /*is_padding=*/false,
223               /*is_whitespace=*/false) {}
224 
TokenToken225   Token(const std::string& arg_value, CodepointIndex arg_start,
226         CodepointIndex arg_end, bool is_padding, bool is_whitespace)
227       : value(arg_value),
228         start(arg_start),
229         end(arg_end),
230         is_padding(is_padding),
231         is_whitespace(is_whitespace) {}
232 
233   bool operator==(const Token& other) const {
234     return value == other.value && start == other.start && end == other.end &&
235            is_padding == other.is_padding;
236   }
237 
IsContainedInSpanToken238   bool IsContainedInSpan(const CodepointSpan& span) const {
239     return start >= span.first && end <= span.second;
240   }
241 };
242 
243 // Pretty-printing function for Token.
244 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
245                                          const Token& token);
246 
247 // Returns a TokenSpan that merges all of the given token spans.
AllOf(const std::vector<Token> & tokens)248 inline TokenSpan AllOf(const std::vector<Token>& tokens) {
249   return {0, static_cast<TokenIndex>(tokens.size())};
250 }
251 
// Granularity of a parsed date/time, ordered from the coarsest unit (year) to
// the finest (second).
enum DatetimeGranularity {
  // Doubles as the "uninitialized" marker for structures holding a
  // granularity.
  GRANULARITY_UNKNOWN = -1,
  GRANULARITY_YEAR = 0,
  GRANULARITY_MONTH = 1,
  GRANULARITY_WEEK = 2,
  GRANULARITY_DAY = 3,
  GRANULARITY_HOUR = 4,
  GRANULARITY_MINUTE = 5,
  GRANULARITY_SECOND = 6
};

// One unit of a date/time expression. For example:
// - {March 21, 2019} has month {March}, day of month {21} and year {2019};
// - {8:00 am} has hour {8}, minute {0} and meridiem {am}.
struct DatetimeComponent {
  enum class ComponentType {
    UNSPECIFIED = 0,
    // Year of the date seen in the text match.
    YEAR = 1,
    // Month of the year, January being 1.
    MONTH = 2,
    // Week (7 days).
    WEEK = 3,
    // Day of the week; the week starts on Sunday, whose value is 1.
    DAY_OF_WEEK = 4,
    // Day of the month, starting at 1.
    DAY_OF_MONTH = 5,
    // Hour of the day in the range 0-23. Values below 12 need MERIDIEM or
    // heuristics to pin down the actual time.
    HOUR = 6,
    // Minute of the hour in the range 0-59.
    MINUTE = 7,
    // Second of the minute in the range 0-59.
    SECOND = 8,
    // Meridiem: 0 means AM, 1 means PM.
    MERIDIEM = 9,
    // Offset from UTC, in minutes.
    ZONE_OFFSET = 10,
    // Daylight-saving offset, in hours.
    DST_OFFSET = 11,
  };

  // TODO(hassan): Remove RelativeQualifier as in the presence of relative
  //               count RelativeQualifier is redundant.
  // Qualifier of a relative component, e.g. "next Monday", "the following
  // day", "tomorrow".
  enum class RelativeQualifier {
    UNSPECIFIED = 0,
    NEXT = 1,
    THIS = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  // Member-wise equality over all four data fields.
  bool operator==(const DatetimeComponent& other) const {
    if (component_type != other.component_type) {
      return false;
    }
    if (relative_qualifier != other.relative_qualifier) {
      return false;
    }
    if (relative_count != other.relative_count) {
      return false;
    }
    return value == other.value;
  }

  // Declared here; the definition is not part of this header.
  bool ShouldRoundToGranularity() const;

  ComponentType component_type = ComponentType::UNSPECIFIED;
  RelativeQualifier relative_qualifier = RelativeQualifier::UNSPECIFIED;

  // Absolute value of the component.
  int value = 0;
  // Number of units of change carried by a relative component.
  int relative_count = 0;

  DatetimeComponent() = default;

  explicit DatetimeComponent(ComponentType arg_component_type,
                             RelativeQualifier arg_relative_qualifier,
                             int arg_value, int arg_relative_count)
      : component_type(arg_component_type),
        relative_qualifier(arg_relative_qualifier),
        value(arg_value),
        relative_count(arg_relative_count) {}
};
340 
341 // Utility method to calculate Returns the finest granularity of
342 // DatetimeComponents.
343 DatetimeGranularity GetFinestGranularity(
344     const std::vector<DatetimeComponent>& datetime_component);
345 
346 // Return the 'DatetimeComponent' from collection filter by component type.
347 Optional<DatetimeComponent> GetDatetimeComponent(
348     const std::vector<DatetimeComponent>& datetime_components,
349     const DatetimeComponent::ComponentType& component_type);
350 
351 struct DatetimeParseResult {
352   // The absolute time in milliseconds since the epoch in UTC.
353   int64 time_ms_utc;
354 
355   // The precision of the estimate then in to calculating the milliseconds
356   DatetimeGranularity granularity;
357 
358   // List of parsed DateTimeComponent.
359   std::vector<DatetimeComponent> datetime_components;
360 
DatetimeParseResultDatetimeParseResult361   DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
362 
DatetimeParseResultDatetimeParseResult363   DatetimeParseResult(int64 arg_time_ms_utc,
364                       DatetimeGranularity arg_granularity,
365                       std::vector<DatetimeComponent> arg_datetime__components)
366       : time_ms_utc(arg_time_ms_utc),
367         granularity(arg_granularity),
368         datetime_components(arg_datetime__components) {}
369 
IsSetDatetimeParseResult370   bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
371 
372   bool operator==(const DatetimeParseResult& other) const {
373     return granularity == other.granularity &&
374            time_ms_utc == other.time_ms_utc &&
375            datetime_components == other.datetime_components;
376   }
377 };
378 
379 const float kFloatCompareEpsilon = 1e-5;
380 
381 struct DatetimeParseResultSpan {
382   CodepointSpan span;
383   std::vector<DatetimeParseResult> data;
384   float target_classification_score;
385   float priority_score;
386 
DatetimeParseResultSpanDatetimeParseResultSpan387   DatetimeParseResultSpan()
388       : span(CodepointSpan::kInvalid),
389         target_classification_score(-1.0),
390         priority_score(-1.0) {}
391 
DatetimeParseResultSpanDatetimeParseResultSpan392   DatetimeParseResultSpan(const CodepointSpan& span,
393                           const std::vector<DatetimeParseResult>& data,
394                           const float target_classification_score,
395                           const float priority_score)
396       : span(span),
397         data(data),
398         target_classification_score(target_classification_score),
399         priority_score(priority_score) {}
400 
401   bool operator==(const DatetimeParseResultSpan& other) const {
402     return span == other.span && data == other.data &&
403            std::abs(target_classification_score -
404                     other.target_classification_score) < kFloatCompareEpsilon &&
405            std::abs(priority_score - other.priority_score) <
406                kFloatCompareEpsilon;
407   }
408 };
409 
410 // Pretty-printing function for DatetimeParseResultSpan.
411 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
412                                          const DatetimeParseResultSpan& value);
413 
// Information intended to uniquely identify a device contact. Instances are
// created by the Knowledge Engine and dereferenced by the Contact Engine.
struct ContactPointer {
  std::string focus_contact_id;
  std::string device_id;
  std::string device_contact_id;
  std::string contact_name;
  std::string contact_name_hash;

  // Two pointers are equal when every identifying field matches.
  bool operator==(const ContactPointer& other) const {
    return !(focus_contact_id != other.focus_contact_id ||
             device_id != other.device_id ||
             device_contact_id != other.device_contact_id ||
             contact_name != other.contact_name ||
             contact_name_hash != other.contact_name_hash);
  }
};
432 
433 struct ClassificationResult {
434   std::string collection;
435   float score;
436   DatetimeParseResult datetime_parse_result;
437   std::string serialized_knowledge_result;
438   ContactPointer contact_pointer;
439   std::string contact_name, contact_given_name, contact_family_name,
440       contact_nickname, contact_email_address, contact_phone_number,
441       contact_account_type, contact_account_name, contact_id,
442       contact_alternate_name;
443   int64 contact_recognition_source;
444   float contact_neural_match_score;
445   std::string app_name, app_package_name;
446   int64 numeric_value;
447   double numeric_double_value;
448 
449   // Length of the parsed duration in milliseconds.
450   int64 duration_ms;
451 
452   // Internal score used for conflict resolution.
453   float priority_score;
454 
455 
456   // Entity data information.
457   std::string serialized_entity_data;
entity_dataClassificationResult458   const EntityData* entity_data() const {
459     return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
460                                                serialized_entity_data.size());
461   }
462 
ClassificationResultClassificationResult463   explicit ClassificationResult()
464       : score(-1.0f),
465         numeric_value(0),
466         numeric_double_value(0.),
467         duration_ms(0),
468         priority_score(-1.0) {}
469 
ClassificationResultClassificationResult470   ClassificationResult(const std::string& arg_collection, float arg_score)
471       : collection(arg_collection),
472         score(arg_score),
473         numeric_value(0),
474         numeric_double_value(0.),
475         duration_ms(0),
476         priority_score(arg_score) {}
477 
ClassificationResultClassificationResult478   ClassificationResult(const std::string& arg_collection, float arg_score,
479                        float arg_priority_score)
480       : collection(arg_collection),
481         score(arg_score),
482         numeric_value(0),
483         numeric_double_value(0.),
484         duration_ms(0),
485         priority_score(arg_priority_score) {}
486 
487   bool operator!=(const ClassificationResult& other) const {
488     return !(*this == other);
489   }
490 
491   bool operator==(const ClassificationResult& other) const;
492 };
493 
494 // Aliases for long enum values.
495 const AnnotationUsecase ANNOTATION_USECASE_SMART =
496     AnnotationUsecase_ANNOTATION_USECASE_SMART;
497 const AnnotationUsecase ANNOTATION_USECASE_RAW =
498     AnnotationUsecase_ANNOTATION_USECASE_RAW;
499 
// Location information that can accompany an annotation request.
struct LocationContext {
  // User location latitude in degrees. The default, 180, lies outside the
  // valid latitude range; presumably it marks "unknown" — TODO confirm.
  double user_location_lat = 180.;

  // User location longitude in degrees. The default, 360, lies outside the
  // valid longitude range; presumably it marks "unknown" — TODO confirm.
  double user_location_lng = 360.;

  // Estimated horizontal accuracy of the user location in meters, analogous
  // to android.location.Location accuracy.
  float user_location_accuracy_meters = 0.f;

  // Equal when every field differs by less than 1e-8.
  bool operator==(const LocationContext& other) const {
    // Each difference is computed in the field's own type before being
    // widened to double.
    const auto within_tolerance = [](double difference) {
      return std::fabs(difference) < 1e-8;
    };
    return within_tolerance(user_location_lat - other.user_location_lat) &&
           within_tolerance(user_location_lng - other.user_location_lng) &&
           within_tolerance(user_location_accuracy_meters -
                            other.user_location_accuracy_meters);
  }
};
520 
521 struct BaseOptions {
522   // Comma-separated list of locale specification for the input text (BCP 47
523   // tags).
524   std::string locales;
525 
526   // Comma-separated list of BCP 47 language tags.
527   std::string detected_text_language_tags;
528 
529   // Tailors the output annotations according to the specified use-case.
530   AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
531 
532   // The location context passed along with each annotation.
533   Optional<LocationContext> location_context;
534 
535   // If true, the POD NER annotator is used.
536   bool use_pod_ner = true;
537 
538   // If true and the model file supports that, the new vocab annotator is used
539   // to annotate "Dictionary". Otherwise, we use the FFModel to do so.
540   bool use_vocab_annotator = true;
541 
542   bool operator==(const BaseOptions& other) const {
543     bool location_context_equality = this->location_context.has_value() ==
544                                      other.location_context.has_value();
545     if (this->location_context.has_value() &&
546         other.location_context.has_value()) {
547       location_context_equality =
548           this->location_context.value() == other.location_context.value();
549     }
550     return this->locales == other.locales &&
551            this->annotation_usecase == other.annotation_usecase &&
552            this->detected_text_language_tags ==
553                other.detected_text_language_tags &&
554            location_context_equality &&
555            this->use_pod_ner == other.use_pod_ner &&
556            this->use_vocab_annotator == other.use_vocab_annotator;
557   }
558 };
559 
560 struct DatetimeOptions {
561   // For parsing relative datetimes, the reference now time against which the
562   // relative datetimes get resolved.
563   // UTC milliseconds since epoch.
564   int64 reference_time_ms_utc = 0;
565 
566   // Timezone in which the input text was written (format as accepted by ICU).
567   std::string reference_timezone;
568 
569   bool operator==(const DatetimeOptions& other) const {
570     return this->reference_time_ms_utc == other.reference_time_ms_utc &&
571            this->reference_timezone == other.reference_timezone;
572   }
573 };
574 
575 struct SelectionOptions : public BaseOptions {};
576 
577 struct ClassificationOptions : public BaseOptions, public DatetimeOptions {
578   // Comma-separated list of language tags which the user can read and
579   // understand (BCP 47).
580   std::string user_familiar_language_tags;
581   // If true, trigger dictionary on words that are of beginner level.
582   bool trigger_dictionary_on_beginner_words = false;
583   // If true, generate *Add* contact intent for email/phone entity.
584   bool enable_add_contact_intent;
585   // If true, generate *Search* intent for named entities.
586   bool enable_search_intent;
587 
588   bool operator==(const ClassificationOptions& other) const {
589     return this->user_familiar_language_tags ==
590                other.user_familiar_language_tags &&
591            this->trigger_dictionary_on_beginner_words ==
592                other.trigger_dictionary_on_beginner_words &&
593            this->enable_add_contact_intent == other.enable_add_contact_intent &&
594            this->enable_search_intent == other.enable_search_intent &&
595            BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
596   }
597 };
598 
// Permissions that gate which user data the annotators may rely on.
struct Permissions {
  // If true, the user location can be used to provide better annotations.
  bool has_location_permission = true;
  // If true, annotators can use personal data to provide personalized
  // annotations.
  bool has_personalization_permission = true;

  bool operator==(const Permissions& other) const {
    if (has_location_permission != other.has_location_permission) {
      return false;
    }
    return has_personalization_permission ==
           other.has_personalization_permission;
  }
};
612 
613 struct AnnotationOptions : public BaseOptions, public DatetimeOptions {
614   // List of entity types that should be used for annotation.
615   std::unordered_set<std::string> entity_types;
616 
617   // If true, serialized_entity_data in the results is populated."
618   bool is_serialized_entity_data_enabled = false;
619 
620   // Defines the permissions for the annotators.
621   Permissions permissions;
622 
623   AnnotateMode annotate_mode = AnnotateMode::kEntityAnnotation;
624 
625   // If true, trigger dictionary on words that are of beginner level.
626   bool trigger_dictionary_on_beginner_words = false;
627 
628   bool operator==(const AnnotationOptions& other) const {
629     return this->is_serialized_entity_data_enabled ==
630                other.is_serialized_entity_data_enabled &&
631            this->permissions == other.permissions &&
632            this->entity_types == other.entity_types &&
633            this->annotate_mode == other.annotate_mode &&
634            this->trigger_dictionary_on_beginner_words ==
635                other.trigger_dictionary_on_beginner_words &&
636            BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
637   }
638 };
639 
640 // Returns true when ClassificationResults are euqal up to scores.
641 bool ClassificationResultsEqualIgnoringScoresAndSerializedEntityData(
642     const ClassificationResult& a, const ClassificationResult& b);
643 
644 // Pretty-printing function for ClassificationResult.
645 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
646                                          const ClassificationResult& result);
647 
648 // Pretty-printing function for std::vector<ClassificationResult>.
649 logging::LoggingStringStream& operator<<(
650     logging::LoggingStringStream& stream,
651     const std::vector<ClassificationResult>& results);
652 
653 // Represents a result of Annotate call.
654 struct AnnotatedSpan {
655   enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME, PERSON_NAME };
656 
657   // Unicode codepoint indices in the input string.
658   CodepointSpan span = CodepointSpan::kInvalid;
659 
660   // Classification result for the span.
661   std::vector<ClassificationResult> classification;
662 
663   // The source of the annotation, used in conflict resolution.
664   Source source = Source::OTHER;
665 
666   AnnotatedSpan() = default;
667 
AnnotatedSpanAnnotatedSpan668   AnnotatedSpan(CodepointSpan arg_span,
669                 std::vector<ClassificationResult> arg_classification)
670       : span(arg_span), classification(std::move(arg_classification)) {}
671 
AnnotatedSpanAnnotatedSpan672   AnnotatedSpan(CodepointSpan arg_span,
673                 std::vector<ClassificationResult> arg_classification,
674                 Source arg_source)
675       : span(arg_span),
676         classification(std::move(arg_classification)),
677         source(arg_source) {}
678 };
679 
680 // Represents Annotations that correspond to all input fragments.
681 struct Annotations {
682   // List of annotations found in the corresponding input fragments. For these
683   // annotations, topicality score will not be set.
684   std::vector<std::vector<AnnotatedSpan>> annotated_spans;
685 
686   // List of topicality results found across all input fragments.
687   std::vector<ClassificationResult> topicality_results;
688 
689   Annotations() = default;
690 
AnnotationsAnnotations691   explicit Annotations(
692       std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans)
693       : annotated_spans(std::move(arg_annotated_spans)) {}
694 
AnnotationsAnnotations695   Annotations(std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans,
696               std::vector<ClassificationResult> arg_topicality_results)
697       : annotated_spans(std::move(arg_annotated_spans)),
698         topicality_results(std::move(arg_topicality_results)) {}
699 };
700 
701 struct InputFragment {
702   std::string text;
703   float bounding_box_top;
704   float bounding_box_height;
705 
706   // If present will override the AnnotationOptions reference time and timezone
707   // when annotating this specific string fragment.
708   Optional<DatetimeOptions> datetime_options;
709 };
710 
711 // Pretty-printing function for AnnotatedSpan.
712 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
713                                          const AnnotatedSpan& span);
714 
// Non-owning view of a contiguous range of a std::vector<T>; the StringPiece
// analogue for vectors. It stores iterators, so the viewed vector must outlive
// the span and must not be reallocated while the span is in use.
template <class T>
class VectorSpan {
 public:
  VectorSpan() : begin_(), end_() {}

  // Views the whole of `v`.
  explicit VectorSpan(const std::vector<T>& v)
      : begin_(v.begin()), end_(v.end()) {}

  // Views the range [begin, end) of some vector.
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  // Element `i` of the view; no bounds checking.
  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  int size() const { return end_ - begin_; }
  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Pointer to the first viewed element. The return type used to be
  // hard-coded to `const float*`, which only compiled for VectorSpan<float>.
  // This dereferences begin_, so it must not be called when begin_ is the
  // vector's end() (e.g. a span over an empty vector).
  const T* data() const { return &(*begin_); }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
739 
740 // Class to provide representation of date and time expressions
741 class DatetimeParsedData {
742  public:
743   // Function to set the absolute value of DateTimeComponent for the given
744   // FieldType, if the field is not present it will create the field and set
745   // the value.
746   void SetAbsoluteValue(const DatetimeComponent::ComponentType& field_type,
747                         int value);
748 
749   // Function to set the relative value of DateTimeComponent, if the field is
750   // not present the function will create the field and set the relative value.
751   void SetRelativeValue(
752       const DatetimeComponent::ComponentType& field_type,
753       const DatetimeComponent::RelativeQualifier& relative_value);
754 
755   // Add collection of 'DatetimeComponent' to 'DatetimeParsedData'.
756   void AddDatetimeComponents(
757       const std::vector<DatetimeComponent>& datetime_components);
758 
759   // Function to set the relative count of DateTimeComponent, if the field is
760   // not present the function will create the field and set the count.
761   void SetRelativeCount(const DatetimeComponent::ComponentType& field_type,
762                         int relative_count);
763 
764   // Function to populate the absolute value of the FieldType and return true.
765   // In case of no FieldType function will return false.
766   bool GetFieldValue(const DatetimeComponent::ComponentType& field_type,
767                      int* field_value) const;
768 
769   // Function to populate the relative value of the FieldType and return true.
770   // In case of no relative value function will return false.
771   bool GetRelativeValue(
772       const DatetimeComponent::ComponentType& field_type,
773       DatetimeComponent::RelativeQualifier* relative_value) const;
774 
775   // Returns relative DateTimeComponent from the parsed DateTime span.
776   void GetRelativeDatetimeComponents(
777       std::vector<DatetimeComponent>* date_time_components) const;
778 
779   // Returns DateTimeComponent from the parsed DateTime span.
780   void GetDatetimeComponents(
781       std::vector<DatetimeComponent>* date_time_components) const;
782 
783   // Represent the granularity of the Parsed DateTime span. The function will
784   // return “GRANULARITY_UNKNOWN” if no datetime field is set.
785   DatetimeGranularity GetFinestGranularity() const;
786 
787   // Utility function to check if DateTimeParsedData has FieldType initialized.
788   bool HasFieldType(const DatetimeComponent::ComponentType& field_type) const;
789 
790   // Function to check if DateTimeParsedData has relative DateTimeComponent for
791   // given FieldType.
792   bool HasRelativeValue(
793       const DatetimeComponent::ComponentType& field_type) const;
794 
795   // Function to check if DateTimeParsedData has absolute value
796   // DateTimeComponent for given FieldType.
797   bool HasAbsoluteValue(
798       const DatetimeComponent::ComponentType& field_type) const;
799 
800   // Function to check if DateTimeParsedData has any DateTimeComponent.
801   bool IsEmpty() const;
802 
803  private:
804   DatetimeComponent& GetOrCreateDatetimeComponent(
805 
806       const DatetimeComponent::ComponentType& component_type);
807 
808   std::map<DatetimeComponent::ComponentType, DatetimeComponent>
809       date_time_components_;
810 };
811 
812 // Pretty-printing function for DateTimeParsedData.
813 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
814                                          const DatetimeParsedData& data);
815 
816 }  // namespace libtextclassifier3
817 
818 #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
819