1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
18 #define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
19
20 #include <time.h>
21
22 #include <algorithm>
23 #include <cmath>
24 #include <functional>
25 #include <map>
26 #include <set>
27 #include <string>
28 #include <unordered_set>
29 #include <utility>
30 #include <vector>
31
32 #include "annotator/entity-data_generated.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "utils/base/integral_types.h"
35 #include "utils/base/logging.h"
36 #include "utils/flatbuffers/flatbuffers.h"
37 #include "utils/optional.h"
38 #include "utils/variant.h"
39
40 namespace libtextclassifier3 {
41
// Sentinel meaning "no index".
constexpr int kInvalidIndex = -1;

// Day-of-week values: 1-based, with Sunday as the first day.
constexpr int kSunday = 1;
constexpr int kMonday = 2;
constexpr int kTuesday = 3;
constexpr int kWednesday = 4;
constexpr int kThursday = 5;
constexpr int kFriday = 6;
constexpr int kSaturday = 7;

// Index for a 0-based array of tokens.
using TokenIndex = int;

// Index for a 0-based array of codepoints.
using CodepointIndex = int;
56
57 // Marks a span in a sequence of codepoints. The first element is the index of
58 // the first codepoint of the span, and the second element is the index of the
59 // codepoint one past the end of the span.
60 struct CodepointSpan {
61 static const CodepointSpan kInvalid;
62
CodepointSpanCodepointSpan63 CodepointSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
64
CodepointSpanCodepointSpan65 CodepointSpan(CodepointIndex start, CodepointIndex end)
66 : first(start), second(end) {}
67
68 CodepointSpan& operator=(const CodepointSpan& other) = default;
69
70 bool operator==(const CodepointSpan& other) const {
71 return this->first == other.first && this->second == other.second;
72 }
73
74 bool operator!=(const CodepointSpan& other) const {
75 return !(*this == other);
76 }
77
78 bool operator<(const CodepointSpan& other) const {
79 if (this->first != other.first) {
80 return this->first < other.first;
81 }
82 return this->second < other.second;
83 }
84
IsValidCodepointSpan85 bool IsValid() const {
86 return this->first != kInvalidIndex && this->second != kInvalidIndex &&
87 this->first <= this->second && this->first >= 0;
88 }
89
IsEmptyCodepointSpan90 bool IsEmpty() const { return this->first == this->second; }
91
92 CodepointIndex first;
93 CodepointIndex second;
94 };
95
96 // Pretty-printing function for CodepointSpan.
97 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
98 const CodepointSpan& span);
99
SpansOverlap(const CodepointSpan & a,const CodepointSpan & b)100 inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
101 return a.first < b.second && b.first < a.second;
102 }
103
SpanContains(const CodepointSpan & span,const CodepointSpan & sub_span)104 inline bool SpanContains(const CodepointSpan& span,
105 const CodepointSpan& sub_span) {
106 return span.first <= sub_span.first && span.second >= sub_span.second;
107 }
108
// Returns true if the span of candidates[considered_candidate] overlaps the
// span of one of the already chosen candidates.
//
// `chosen_indices_set` holds indices into `candidates`, ordered by the set's
// own comparator. Only the two neighbours of `considered_candidate` in that
// ordering are inspected (the lower bound and its predecessor).
// NOTE(review): this is presumably sufficient because the comparator orders
// indices by span position and the chosen spans do not overlap each other —
// confirm against the callers.
//
// T must expose a `span` member usable with SpansOverlap().
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  if (chosen_indices_set.empty()) {
    return false;
  }

  const auto& considered_span = candidates[considered_candidate].span;

  // Check conflict on the right: the first chosen element that is not ordered
  // before the considered candidate.
  auto conflicting_it = chosen_indices_set.lower_bound(considered_candidate);
  if (conflicting_it != chosen_indices_set.end() &&
      SpansOverlap(considered_span, candidates[*conflicting_it].span)) {
    return true;
  }

  // Check conflict on the left.
  // If we can't go more left, there can't be a conflict:
  if (conflicting_it == chosen_indices_set.begin()) {
    return false;
  }

  // Otherwise move one element left; the candidate conflicts exactly when it
  // overlaps that neighbour.
  --conflicting_it;
  return SpansOverlap(considered_span, candidates[*conflicting_it].span);
}
140
141 // Marks a span in a sequence of tokens. The first element is the index of the
142 // first token in the span, and the second element is the index of the token one
143 // past the end of the span.
144 struct TokenSpan {
145 static const TokenSpan kInvalid;
146
TokenSpanTokenSpan147 TokenSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
148
TokenSpanTokenSpan149 TokenSpan(TokenIndex start, TokenIndex end) : first(start), second(end) {}
150
151 // Creates a token span consisting of one token.
TokenSpanTokenSpan152 explicit TokenSpan(int token_index)
153 : first(token_index), second(token_index + 1) {}
154
155 TokenSpan& operator=(const TokenSpan& other) = default;
156
157 bool operator==(const TokenSpan& other) const {
158 return this->first == other.first && this->second == other.second;
159 }
160
161 bool operator!=(const TokenSpan& other) const { return !(*this == other); }
162
163 bool operator<(const TokenSpan& other) const {
164 if (this->first != other.first) {
165 return this->first < other.first;
166 }
167 return this->second < other.second;
168 }
169
IsValidTokenSpan170 bool IsValid() const {
171 return this->first != kInvalidIndex && this->second != kInvalidIndex;
172 }
173
174 // Returns the size of the token span. Assumes that the span is valid.
SizeTokenSpan175 int Size() const { return this->second - this->first; }
176
177 // Returns an expanded token span by adding a certain number of tokens on its
178 // left and on its right.
ExpandTokenSpan179 TokenSpan Expand(int num_tokens_left, int num_tokens_right) const {
180 return {this->first - num_tokens_left, this->second + num_tokens_right};
181 }
182
183 TokenIndex first;
184 TokenIndex second;
185 };
186
187 // Pretty-printing function for TokenSpan.
188 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
189 const TokenSpan& span);
190
191 // Returns an intersection of two token spans. Assumes that both spans are
192 // valid and overlapping.
IntersectTokenSpans(const TokenSpan & token_span1,const TokenSpan & token_span2)193 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
194 const TokenSpan& token_span2) {
195 return {std::max(token_span1.first, token_span2.first),
196 std::min(token_span1.second, token_span2.second)};
197 }
198
199 // Token holds a token, its position in the original string and whether it was
200 // part of the input span.
201 struct Token {
202 std::string value;
203 CodepointIndex start;
204 CodepointIndex end;
205
206 // Whether the token is a padding token.
207 bool is_padding;
208
209 // Whether the token contains only white characters.
210 bool is_whitespace;
211
212 // Default constructor constructs the padding-token.
TokenToken213 Token()
214 : Token(/*arg_value=*/"", /*arg_start=*/kInvalidIndex,
215 /*arg_end=*/kInvalidIndex, /*is_padding=*/true,
216 /*is_whitespace=*/false) {}
217
TokenToken218 Token(const std::string& arg_value, CodepointIndex arg_start,
219 CodepointIndex arg_end)
220 : Token(/*arg_value=*/arg_value, /*arg_start=*/arg_start,
221 /*arg_end=*/arg_end, /*is_padding=*/false,
222 /*is_whitespace=*/false) {}
223
TokenToken224 Token(const std::string& arg_value, CodepointIndex arg_start,
225 CodepointIndex arg_end, bool is_padding, bool is_whitespace)
226 : value(arg_value),
227 start(arg_start),
228 end(arg_end),
229 is_padding(is_padding),
230 is_whitespace(is_whitespace) {}
231
232 bool operator==(const Token& other) const {
233 return value == other.value && start == other.start && end == other.end &&
234 is_padding == other.is_padding;
235 }
236
IsContainedInSpanToken237 bool IsContainedInSpan(const CodepointSpan& span) const {
238 return start >= span.first && end <= span.second;
239 }
240 };
241
242 // Pretty-printing function for Token.
243 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
244 const Token& token);
245
246 // Returns a TokenSpan that merges all of the given token spans.
AllOf(const std::vector<Token> & tokens)247 inline TokenSpan AllOf(const std::vector<Token>& tokens) {
248 return {0, static_cast<TokenIndex>(tokens.size())};
249 }
250
// Granularity of a parsed datetime. Values increase from coarse (year) to
// fine (second).
enum DatetimeGranularity {
  // GRANULARITY_UNKNOWN is used as a proxy for this structure being
  // uninitialized.
  GRANULARITY_UNKNOWN = -1,
  GRANULARITY_YEAR = 0,
  GRANULARITY_MONTH = 1,
  GRANULARITY_WEEK = 2,
  GRANULARITY_DAY = 3,
  GRANULARITY_HOUR = 4,
  GRANULARITY_MINUTE = 5,
  GRANULARITY_SECOND = 6
};
262
// This struct represents a unit of date and time expression.
// Examples include:
// - In {March 21, 2019} datetime components are month: {March},
//   day of month: {21} and year: {2019}.
// - {8:00 am} contains hour: {8}, minutes: {0} and am/pm: {am}
struct DatetimeComponent {
  enum class ComponentType {
    UNSPECIFIED = 0,
    // Year of the date seen in the text match.
    YEAR = 1,
    // Month of the year starting with January = 1.
    MONTH = 2,
    // Week (7 days).
    WEEK = 3,
    // Day of week, start of the week is Sunday & its value is 1.
    DAY_OF_WEEK = 4,
    // Day of the month starting with 1.
    DAY_OF_MONTH = 5,
    // Hour of the day with a range of 0-23,
    // values less than 12 need the AMPM field below or heuristics
    // to definitively determine the time.
    HOUR = 6,
    // Minute of the hour with a range of 0-59.
    MINUTE = 7,
    // Seconds of the minute with a range of 0-59.
    SECOND = 8,
    // Meridiem field where 0 == AM, 1 == PM.
    MERIDIEM = 9,
    // Offset in number of minutes from UTC this date time is in.
    ZONE_OFFSET = 10,
    // Offset in number of hours for DST.
    DST_OFFSET = 11,
  };

  // TODO(hassan): Remove RelativeQualifier as in the presence of relative
  //               count RelativeQualifier is redundant.
  // Enum to represent the relative DateTimeComponent e.g. "next Monday",
  // "the following day", "tomorrow".
  enum class RelativeQualifier {
    UNSPECIFIED = 0,
    NEXT = 1,
    THIS = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  // Two components are equal when all four fields match.
  bool operator==(const DatetimeComponent& other) const {
    return component_type == other.component_type &&
           relative_qualifier == other.relative_qualifier &&
           relative_count == other.relative_count && value == other.value;
  }

  // Declared here only; defined outside this header.
  bool ShouldRoundToGranularity() const;

  ComponentType component_type = ComponentType::UNSPECIFIED;
  RelativeQualifier relative_qualifier = RelativeQualifier::UNSPECIFIED;

  // Represents the absolute value of DateTime components.
  int value = 0;
  // The number of units of change present in the relative DateTimeComponent.
  int relative_count = 0;

  DatetimeComponent() = default;

  explicit DatetimeComponent(ComponentType arg_component_type,
                             RelativeQualifier arg_relative_qualifier,
                             int arg_value, int arg_relative_count)
      : component_type(arg_component_type),
        relative_qualifier(arg_relative_qualifier),
        value(arg_value),
        relative_count(arg_relative_count) {}
};
339
340 // Utility method to calculate Returns the finest granularity of
341 // DatetimeComponents.
342 DatetimeGranularity GetFinestGranularity(
343 const std::vector<DatetimeComponent>& datetime_component);
344
345 // Return the 'DatetimeComponent' from collection filter by component type.
346 Optional<DatetimeComponent> GetDatetimeComponent(
347 const std::vector<DatetimeComponent>& datetime_components,
348 const DatetimeComponent::ComponentType& component_type);
349
350 struct DatetimeParseResult {
351 // The absolute time in milliseconds since the epoch in UTC.
352 int64 time_ms_utc;
353
354 // The precision of the estimate then in to calculating the milliseconds
355 DatetimeGranularity granularity;
356
357 // List of parsed DateTimeComponent.
358 std::vector<DatetimeComponent> datetime_components;
359
DatetimeParseResultDatetimeParseResult360 DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
361
DatetimeParseResultDatetimeParseResult362 DatetimeParseResult(int64 arg_time_ms_utc,
363 DatetimeGranularity arg_granularity,
364 std::vector<DatetimeComponent> arg_datetime__components)
365 : time_ms_utc(arg_time_ms_utc),
366 granularity(arg_granularity),
367 datetime_components(arg_datetime__components) {}
368
IsSetDatetimeParseResult369 bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
370
371 bool operator==(const DatetimeParseResult& other) const {
372 return granularity == other.granularity &&
373 time_ms_utc == other.time_ms_utc &&
374 datetime_components == other.datetime_components;
375 }
376 };
377
378 const float kFloatCompareEpsilon = 1e-5;
379
380 struct DatetimeParseResultSpan {
381 CodepointSpan span;
382 std::vector<DatetimeParseResult> data;
383 float target_classification_score;
384 float priority_score;
385
DatetimeParseResultSpanDatetimeParseResultSpan386 DatetimeParseResultSpan()
387 : span(CodepointSpan::kInvalid),
388 target_classification_score(-1.0),
389 priority_score(-1.0) {}
390
DatetimeParseResultSpanDatetimeParseResultSpan391 DatetimeParseResultSpan(const CodepointSpan& span,
392 const std::vector<DatetimeParseResult>& data,
393 const float target_classification_score,
394 const float priority_score)
395 : span(span),
396 data(data),
397 target_classification_score(target_classification_score),
398 priority_score(priority_score) {}
399
400 bool operator==(const DatetimeParseResultSpan& other) const {
401 return span == other.span && data == other.data &&
402 std::abs(target_classification_score -
403 other.target_classification_score) < kFloatCompareEpsilon &&
404 std::abs(priority_score - other.priority_score) <
405 kFloatCompareEpsilon;
406 }
407 };
408
409 // Pretty-printing function for DatetimeParseResultSpan.
410 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
411 const DatetimeParseResultSpan& value);
412
// This struct contains information intended to uniquely identify a device
// contact. Instances are created by the Knowledge Engine, and dereferenced by
// the Contact Engine.
struct ContactPointer {
  std::string focus_contact_id;
  std::string device_id;
  std::string device_contact_id;
  std::string contact_name;
  std::string contact_name_hash;

  // Two pointers are equal when all five string fields match.
  bool operator==(const ContactPointer& other) const {
    return focus_contact_id == other.focus_contact_id &&
           device_id == other.device_id &&
           device_contact_id == other.device_contact_id &&
           contact_name == other.contact_name &&
           contact_name_hash == other.contact_name_hash;
  }
};
431
432 struct ClassificationResult {
433 std::string collection;
434 float score;
435 DatetimeParseResult datetime_parse_result;
436 std::string serialized_knowledge_result;
437 ContactPointer contact_pointer;
438 std::string contact_name, contact_given_name, contact_family_name,
439 contact_nickname, contact_email_address, contact_phone_number,
440 contact_account_type, contact_account_name, contact_id,
441 contact_alternate_name;
442 std::string app_name, app_package_name;
443 int64 numeric_value;
444 double numeric_double_value;
445
446 // Length of the parsed duration in milliseconds.
447 int64 duration_ms;
448
449 // Internal score used for conflict resolution.
450 float priority_score;
451
452
453 // Entity data information.
454 std::string serialized_entity_data;
entity_dataClassificationResult455 const EntityData* entity_data() const {
456 return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
457 serialized_entity_data.size());
458 }
459
ClassificationResultClassificationResult460 explicit ClassificationResult()
461 : score(-1.0f),
462 numeric_value(0),
463 numeric_double_value(0.),
464 duration_ms(0),
465 priority_score(-1.0) {}
466
ClassificationResultClassificationResult467 ClassificationResult(const std::string& arg_collection, float arg_score)
468 : collection(arg_collection),
469 score(arg_score),
470 numeric_value(0),
471 numeric_double_value(0.),
472 duration_ms(0),
473 priority_score(arg_score) {}
474
ClassificationResultClassificationResult475 ClassificationResult(const std::string& arg_collection, float arg_score,
476 float arg_priority_score)
477 : collection(arg_collection),
478 score(arg_score),
479 numeric_value(0),
480 numeric_double_value(0.),
481 duration_ms(0),
482 priority_score(arg_priority_score) {}
483
484 bool operator!=(const ClassificationResult& other) const {
485 return !(*this == other);
486 }
487
488 bool operator==(const ClassificationResult& other) const;
489 };
490
491 // Aliases for long enum values.
492 const AnnotationUsecase ANNOTATION_USECASE_SMART =
493 AnnotationUsecase_ANNOTATION_USECASE_SMART;
494 const AnnotationUsecase ANNOTATION_USECASE_RAW =
495 AnnotationUsecase_ANNOTATION_USECASE_RAW;
496
// Location of the user, passed along with annotation requests.
struct LocationContext {
  // User location latitude in degrees.
  // NOTE(review): the default of 180 lies outside the usual [-90, 90] latitude
  // range, so it presumably acts as a "not set" marker — confirm.
  double user_location_lat = 180.;

  // User location longitude in degrees.
  // NOTE(review): the default of 360 lies outside the usual [-180, 180]
  // longitude range, so it presumably acts as a "not set" marker — confirm.
  double user_location_lng = 360.;

  // The estimated horizontal accuracy of the user location in meters.
  // Analogous to android.location.Location accuracy.
  float user_location_accuracy_meters = 0.f;

  // All three fields must agree to within an absolute tolerance of 1e-8.
  bool operator==(const LocationContext& other) const {
    constexpr double kTolerance = 1e-8;
    return std::fabs(this->user_location_lat - other.user_location_lat) <
               kTolerance &&
           std::fabs(this->user_location_lng - other.user_location_lng) <
               kTolerance &&
           std::fabs(this->user_location_accuracy_meters -
                     other.user_location_accuracy_meters) < kTolerance;
  }
};
517
518 struct BaseOptions {
519 // Comma-separated list of locale specification for the input text (BCP 47
520 // tags).
521 std::string locales;
522
523 // Comma-separated list of BCP 47 language tags.
524 std::string detected_text_language_tags;
525
526 // Tailors the output annotations according to the specified use-case.
527 AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
528
529 // The location context passed along with each annotation.
530 Optional<LocationContext> location_context;
531
532 // If true, the POD NER annotator is used.
533 bool use_pod_ner = true;
534
535 // If true and the model file supports that, the new vocab annotator is used
536 // to annotate "Dictionary". Otherwise, we use the FFModel to do so.
537 bool use_vocab_annotator = true;
538
539 bool operator==(const BaseOptions& other) const {
540 bool location_context_equality = this->location_context.has_value() ==
541 other.location_context.has_value();
542 if (this->location_context.has_value() &&
543 other.location_context.has_value()) {
544 location_context_equality =
545 this->location_context.value() == other.location_context.value();
546 }
547 return this->locales == other.locales &&
548 this->annotation_usecase == other.annotation_usecase &&
549 this->detected_text_language_tags ==
550 other.detected_text_language_tags &&
551 location_context_equality &&
552 this->use_pod_ner == other.use_pod_ner &&
553 this->use_vocab_annotator == other.use_vocab_annotator;
554 }
555 };
556
557 struct DatetimeOptions {
558 // For parsing relative datetimes, the reference now time against which the
559 // relative datetimes get resolved.
560 // UTC milliseconds since epoch.
561 int64 reference_time_ms_utc = 0;
562
563 // Timezone in which the input text was written (format as accepted by ICU).
564 std::string reference_timezone;
565
566 bool operator==(const DatetimeOptions& other) const {
567 return this->reference_time_ms_utc == other.reference_time_ms_utc &&
568 this->reference_timezone == other.reference_timezone;
569 }
570 };
571
572 struct SelectionOptions : public BaseOptions {};
573
574 struct ClassificationOptions : public BaseOptions, public DatetimeOptions {
575 // Comma-separated list of language tags which the user can read and
576 // understand (BCP 47).
577 std::string user_familiar_language_tags;
578 // If true, trigger dictionary on words that are of beginner level.
579 bool trigger_dictionary_on_beginner_words = false;
580
581 bool operator==(const ClassificationOptions& other) const {
582 return this->user_familiar_language_tags ==
583 other.user_familiar_language_tags &&
584 this->trigger_dictionary_on_beginner_words ==
585 other.trigger_dictionary_on_beginner_words &&
586 BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
587 }
588 };
589
// Permission flags for the annotators. Both default to true.
struct Permissions {
  // If true the user location can be used to provide better annotations.
  bool has_location_permission = true;
  // If true, annotators can use personal data to provide personalized
  // annotations.
  bool has_personalization_permission = true;

  bool operator==(const Permissions& other) const {
    return this->has_location_permission == other.has_location_permission &&
           this->has_personalization_permission ==
               other.has_personalization_permission;
  }
};
603
604 struct AnnotationOptions : public BaseOptions, public DatetimeOptions {
605 // List of entity types that should be used for annotation.
606 std::unordered_set<std::string> entity_types;
607
608 // If true, serialized_entity_data in the results is populated."
609 bool is_serialized_entity_data_enabled = false;
610
611 // Defines the permissions for the annotators.
612 Permissions permissions;
613
614 AnnotateMode annotate_mode = AnnotateMode::kEntityAnnotation;
615
616 // If true, trigger dictionary on words that are of beginner level.
617 bool trigger_dictionary_on_beginner_words = false;
618
619 bool operator==(const AnnotationOptions& other) const {
620 return this->is_serialized_entity_data_enabled ==
621 other.is_serialized_entity_data_enabled &&
622 this->permissions == other.permissions &&
623 this->entity_types == other.entity_types &&
624 this->annotate_mode == other.annotate_mode &&
625 this->trigger_dictionary_on_beginner_words ==
626 other.trigger_dictionary_on_beginner_words &&
627 BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
628 }
629 };
630
631 // Returns true when ClassificationResults are euqal up to scores.
632 bool ClassificationResultsEqualIgnoringScoresAndSerializedEntityData(
633 const ClassificationResult& a, const ClassificationResult& b);
634
635 // Pretty-printing function for ClassificationResult.
636 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
637 const ClassificationResult& result);
638
639 // Pretty-printing function for std::vector<ClassificationResult>.
640 logging::LoggingStringStream& operator<<(
641 logging::LoggingStringStream& stream,
642 const std::vector<ClassificationResult>& results);
643
644 // Represents a result of Annotate call.
645 struct AnnotatedSpan {
646 enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME, PERSON_NAME };
647
648 // Unicode codepoint indices in the input string.
649 CodepointSpan span = CodepointSpan::kInvalid;
650
651 // Classification result for the span.
652 std::vector<ClassificationResult> classification;
653
654 // The source of the annotation, used in conflict resolution.
655 Source source = Source::OTHER;
656
657 AnnotatedSpan() = default;
658
AnnotatedSpanAnnotatedSpan659 AnnotatedSpan(CodepointSpan arg_span,
660 std::vector<ClassificationResult> arg_classification)
661 : span(arg_span), classification(std::move(arg_classification)) {}
662
AnnotatedSpanAnnotatedSpan663 AnnotatedSpan(CodepointSpan arg_span,
664 std::vector<ClassificationResult> arg_classification,
665 Source arg_source)
666 : span(arg_span),
667 classification(std::move(arg_classification)),
668 source(arg_source) {}
669 };
670
671 // Represents Annotations that correspond to all input fragments.
672 struct Annotations {
673 // List of annotations found in the corresponding input fragments. For these
674 // annotations, topicality score will not be set.
675 std::vector<std::vector<AnnotatedSpan>> annotated_spans;
676
677 // List of topicality results found across all input fragments.
678 std::vector<ClassificationResult> topicality_results;
679
680 Annotations() = default;
681
AnnotationsAnnotations682 explicit Annotations(
683 std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans)
684 : annotated_spans(std::move(arg_annotated_spans)) {}
685
AnnotationsAnnotations686 Annotations(std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans,
687 std::vector<ClassificationResult> arg_topicality_results)
688 : annotated_spans(std::move(arg_annotated_spans)),
689 topicality_results(std::move(arg_topicality_results)) {}
690 };
691
692 struct InputFragment {
693 std::string text;
694 float bounding_box_top;
695 float bounding_box_height;
696
697 // If present will override the AnnotationOptions reference time and timezone
698 // when annotating this specific string fragment.
699 Optional<DatetimeOptions> datetime_options;
700 };
701
702 // Pretty-printing function for AnnotatedSpan.
703 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
704 const AnnotatedSpan& span);
705
// StringPiece analogue for std::vector<T>: a non-owning view over a range of
// elements of a vector, stored as a pair of const iterators. The view does not
// keep the vector alive, so it must not be used after the underlying iterators
// have been invalidated.
template <class T>
class VectorSpan {
 public:
  // Creates an empty span.
  VectorSpan() : begin_(), end_() {}

  // Creates a span over the whole vector.
  explicit VectorSpan(const std::vector<T>& v)
      : begin_(v.begin()), end_(v.end()) {}

  // Creates a span over [begin, end).
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  // Unchecked element access relative to the start of the span.
  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  int size() const { return static_cast<int>(end_ - begin_); }
  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Pointer to the first element, or nullptr for an empty span. The element
  // type follows T (it used to be hard-coded to float, which only compiled
  // for VectorSpan<float>); the empty check avoids dereferencing an iterator
  // that does not point at an element.
  const T* data() const { return begin_ == end_ ? nullptr : &(*begin_); }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
730
731 // Class to provide representation of date and time expressions
732 class DatetimeParsedData {
733 public:
734 // Function to set the absolute value of DateTimeComponent for the given
735 // FieldType, if the field is not present it will create the field and set
736 // the value.
737 void SetAbsoluteValue(const DatetimeComponent::ComponentType& field_type,
738 int value);
739
740 // Function to set the relative value of DateTimeComponent, if the field is
741 // not present the function will create the field and set the relative value.
742 void SetRelativeValue(
743 const DatetimeComponent::ComponentType& field_type,
744 const DatetimeComponent::RelativeQualifier& relative_value);
745
746 // Add collection of 'DatetimeComponent' to 'DatetimeParsedData'.
747 void AddDatetimeComponents(
748 const std::vector<DatetimeComponent>& datetime_components);
749
750 // Function to set the relative count of DateTimeComponent, if the field is
751 // not present the function will create the field and set the count.
752 void SetRelativeCount(const DatetimeComponent::ComponentType& field_type,
753 int relative_count);
754
755 // Function to populate the absolute value of the FieldType and return true.
756 // In case of no FieldType function will return false.
757 bool GetFieldValue(const DatetimeComponent::ComponentType& field_type,
758 int* field_value) const;
759
760 // Function to populate the relative value of the FieldType and return true.
761 // In case of no relative value function will return false.
762 bool GetRelativeValue(
763 const DatetimeComponent::ComponentType& field_type,
764 DatetimeComponent::RelativeQualifier* relative_value) const;
765
766 // Returns relative DateTimeComponent from the parsed DateTime span.
767 void GetRelativeDatetimeComponents(
768 std::vector<DatetimeComponent>* date_time_components) const;
769
770 // Returns DateTimeComponent from the parsed DateTime span.
771 void GetDatetimeComponents(
772 std::vector<DatetimeComponent>* date_time_components) const;
773
774 // Represent the granularity of the Parsed DateTime span. The function will
775 // return “GRANULARITY_UNKNOWN” if no datetime field is set.
776 DatetimeGranularity GetFinestGranularity() const;
777
778 // Utility function to check if DateTimeParsedData has FieldType initialized.
779 bool HasFieldType(const DatetimeComponent::ComponentType& field_type) const;
780
781 // Function to check if DateTimeParsedData has relative DateTimeComponent for
782 // given FieldType.
783 bool HasRelativeValue(
784 const DatetimeComponent::ComponentType& field_type) const;
785
786 // Function to check if DateTimeParsedData has absolute value
787 // DateTimeComponent for given FieldType.
788 bool HasAbsoluteValue(
789 const DatetimeComponent::ComponentType& field_type) const;
790
791 // Function to check if DateTimeParsedData has any DateTimeComponent.
792 bool IsEmpty() const;
793
794 private:
795 DatetimeComponent& GetOrCreateDatetimeComponent(
796
797 const DatetimeComponent::ComponentType& component_type);
798
799 std::map<DatetimeComponent::ComponentType, DatetimeComponent>
800 date_time_components_;
801 };
802
803 // Pretty-printing function for DateTimeParsedData.
804 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
805 const DatetimeParsedData& data);
806
807 } // namespace libtextclassifier3
808
809 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
810