/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
18 #define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
19
20 #include <time.h>
21
22 #include <algorithm>
23 #include <cmath>
24 #include <functional>
25 #include <map>
26 #include <set>
27 #include <string>
28 #include <unordered_set>
29 #include <utility>
30 #include <vector>
31
32 #include "annotator/entity-data_generated.h"
33 #include "annotator/knowledge/knowledge-engine-types.h"
34 #include "utils/base/integral_types.h"
35 #include "utils/base/logging.h"
36 #include "utils/flatbuffers/flatbuffers.h"
37 #include "utils/optional.h"
38 #include "utils/variant.h"
39
40 namespace libtextclassifier3 {
41
// Sentinel used for "no index" in token and codepoint spans.
constexpr int kInvalidIndex = -1;

// Day-of-week numbering: Sunday is 1, Saturday is 7.
constexpr int kSunday = 1;
constexpr int kMonday = 2;
constexpr int kTuesday = 3;
constexpr int kWednesday = 4;
constexpr int kThursday = 5;
constexpr int kFriday = 6;
constexpr int kSaturday = 7;

// Position inside a 0-based array of tokens.
using TokenIndex = int;

// Position inside a 0-based array of codepoints.
using CodepointIndex = int;
56
57 // Marks a span in a sequence of codepoints. The first element is the index of
58 // the first codepoint of the span, and the second element is the index of the
59 // codepoint one past the end of the span.
60 struct CodepointSpan {
61 static const CodepointSpan kInvalid;
62
CodepointSpanCodepointSpan63 CodepointSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
64
CodepointSpanCodepointSpan65 CodepointSpan(CodepointIndex start, CodepointIndex end)
66 : first(start), second(end) {}
67
68 CodepointSpan& operator=(const CodepointSpan& other) = default;
69
70 bool operator==(const CodepointSpan& other) const {
71 return this->first == other.first && this->second == other.second;
72 }
73
74 bool operator!=(const CodepointSpan& other) const {
75 return !(*this == other);
76 }
77
78 bool operator<(const CodepointSpan& other) const {
79 if (this->first != other.first) {
80 return this->first < other.first;
81 }
82 return this->second < other.second;
83 }
84
IsValidCodepointSpan85 bool IsValid() const {
86 return this->first != kInvalidIndex && this->second != kInvalidIndex &&
87 this->first <= this->second && this->first >= 0;
88 }
89
IsEmptyCodepointSpan90 bool IsEmpty() const { return this->first == this->second; }
91
92 CodepointIndex first;
93 CodepointIndex second;
94 };
95
96 // Pretty-printing function for CodepointSpan.
97 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
98 const CodepointSpan& span);
99
SpansOverlap(const CodepointSpan & a,const CodepointSpan & b)100 inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
101 return a.first < b.second && b.first < a.second;
102 }
103
SpanContains(const CodepointSpan & span,const CodepointSpan & sub_span)104 inline bool SpanContains(const CodepointSpan& span,
105 const CodepointSpan& sub_span) {
106 return span.first <= sub_span.first && span.second >= sub_span.second;
107 }
108
// Returns true when the span of candidates[considered_candidate] overlaps the
// span of one of its two neighbours in `chosen_indices_set`. The neighbours
// are the first chosen index not ordered before the candidate (per the set's
// comparator) and the chosen index immediately before that one. Only those two
// are inspected.
// NOTE(review): this presumes the comparator orders indices by their spans and
// that chosen spans are mutually non-overlapping — confirm at call sites.
// Nothing is inserted into the set.
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  if (chosen_indices_set.empty()) {
    return false;
  }

  const auto& candidate_span = candidates[considered_candidate].span;

  // Neighbour on the right: first chosen index not ordered before candidate.
  auto right = chosen_indices_set.lower_bound(considered_candidate);
  if (right != chosen_indices_set.end() &&
      SpansOverlap(candidate_span, candidates[*right].span)) {
    return true;
  }

  // Nothing is ordered before the candidate, so no left neighbour exists.
  if (right == chosen_indices_set.begin()) {
    return false;
  }

  // Neighbour on the left: the chosen index just before `right`.
  auto left = right;
  --left;
  return SpansOverlap(candidate_span, candidates[*left].span);
}
140
141 // Marks a span in a sequence of tokens. The first element is the index of the
142 // first token in the span, and the second element is the index of the token one
143 // past the end of the span.
144 struct TokenSpan {
145 static const TokenSpan kInvalid;
146
TokenSpanTokenSpan147 TokenSpan() : first(kInvalidIndex), second(kInvalidIndex) {}
148
TokenSpanTokenSpan149 TokenSpan(TokenIndex start, TokenIndex end) : first(start), second(end) {}
150
151 // Creates a token span consisting of one token.
TokenSpanTokenSpan152 explicit TokenSpan(int token_index)
153 : first(token_index), second(token_index + 1) {}
154
155 TokenSpan& operator=(const TokenSpan& other) = default;
156
157 bool operator==(const TokenSpan& other) const {
158 return this->first == other.first && this->second == other.second;
159 }
160
161 bool operator!=(const TokenSpan& other) const { return !(*this == other); }
162
163 bool operator<(const TokenSpan& other) const {
164 if (this->first != other.first) {
165 return this->first < other.first;
166 }
167 return this->second < other.second;
168 }
169
IsValidTokenSpan170 bool IsValid() const {
171 return this->first != kInvalidIndex && this->second != kInvalidIndex;
172 }
173
174 // Returns the size of the token span. Assumes that the span is valid.
SizeTokenSpan175 int Size() const { return this->second - this->first; }
176
177 // Returns an expanded token span by adding a certain number of tokens on its
178 // left and on its right.
ExpandTokenSpan179 TokenSpan Expand(int num_tokens_left, int num_tokens_right) const {
180 return {this->first - num_tokens_left, this->second + num_tokens_right};
181 }
182
183 TokenIndex first;
184 TokenIndex second;
185 };
186
187 // Pretty-printing function for TokenSpan.
188 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
189 const TokenSpan& span);
190
191 // Returns an intersection of two token spans. Assumes that both spans are
192 // valid and overlapping.
IntersectTokenSpans(const TokenSpan & token_span1,const TokenSpan & token_span2)193 inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
194 const TokenSpan& token_span2) {
195 return {std::max(token_span1.first, token_span2.first),
196 std::min(token_span1.second, token_span2.second)};
197 }
198
199 // Token holds a token, its position in the original string and whether it was
200 // part of the input span.
201 struct Token {
202 std::string value;
203 CodepointIndex start;
204 CodepointIndex end;
205
206 // Whether the token is a padding token.
207 bool is_padding;
208
209 // Whether the token contains only white characters.
210 bool is_whitespace;
211
212 // Default constructor constructs the padding-token.
TokenToken213 Token()
214 : Token(/*arg_value=*/"", /*arg_start=*/kInvalidIndex,
215 /*arg_end=*/kInvalidIndex, /*is_padding=*/true,
216 /*is_whitespace=*/false) {}
217
TokenToken218 Token(const std::string& arg_value, CodepointIndex arg_start,
219 CodepointIndex arg_end)
220 : Token(/*arg_value=*/arg_value, /*arg_start=*/arg_start,
221 /*arg_end=*/arg_end, /*is_padding=*/false,
222 /*is_whitespace=*/false) {}
223
TokenToken224 Token(const std::string& arg_value, CodepointIndex arg_start,
225 CodepointIndex arg_end, bool is_padding, bool is_whitespace)
226 : value(arg_value),
227 start(arg_start),
228 end(arg_end),
229 is_padding(is_padding),
230 is_whitespace(is_whitespace) {}
231
232 bool operator==(const Token& other) const {
233 return value == other.value && start == other.start && end == other.end &&
234 is_padding == other.is_padding;
235 }
236
IsContainedInSpanToken237 bool IsContainedInSpan(const CodepointSpan& span) const {
238 return start >= span.first && end <= span.second;
239 }
240 };
241
242 // Pretty-printing function for Token.
243 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
244 const Token& token);
245
246 // Returns a TokenSpan that merges all of the given token spans.
AllOf(const std::vector<Token> & tokens)247 inline TokenSpan AllOf(const std::vector<Token>& tokens) {
248 return {0, static_cast<TokenIndex>(tokens.size())};
249 }
250
// Precision of a parsed datetime, from coarsest (year) to finest (second).
enum DatetimeGranularity {
  // Also used as a proxy for "this structure is uninitialized".
  GRANULARITY_UNKNOWN = -1,
  GRANULARITY_YEAR = 0,
  GRANULARITY_MONTH = 1,
  GRANULARITY_WEEK = 2,
  GRANULARITY_DAY = 3,
  GRANULARITY_HOUR = 4,
  GRANULARITY_MINUTE = 5,
  GRANULARITY_SECOND = 6
};
262
// One unit of a date/time expression. For example:
//  - "March 21, 2019" has the components month {March}, day of month {21}
//    and year {2019};
//  - "8:00 am" has hour {8}, minutes {0} and meridiem {am}.
struct DatetimeComponent {
  enum class ComponentType {
    UNSPECIFIED = 0,
    // Year of the date seen in the text match.
    YEAR = 1,
    // Month of the year, January = 1.
    MONTH = 2,
    // Week (7 days).
    WEEK = 3,
    // Day of the week; weeks start on Sunday, which is 1.
    DAY_OF_WEEK = 4,
    // Day of the month, starting at 1.
    DAY_OF_MONTH = 5,
    // Hour of the day in 0-23. Values below 12 need MERIDIEM or heuristics
    // to pin down the actual time.
    HOUR = 6,
    // Minute of the hour in 0-59.
    MINUTE = 7,
    // Second of the minute in 0-59.
    SECOND = 8,
    // Meridiem: 0 == AM, 1 == PM.
    MERIDIEM = 9,
    // Offset from UTC, in minutes.
    ZONE_OFFSET = 10,
    // Daylight-saving offset, in hours.
    DST_OFFSET = 11,
  };

  // TODO(hassan): Remove RelativeQualifier as in the presence of relative
  // count RelativeQualifier is redundant.
  // Qualifier of a relative component, e.g. "next Monday", "the following
  // day", "tomorrow".
  enum class RelativeQualifier {
    UNSPECIFIED = 0,
    NEXT = 1,
    THIS = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  // Compares type, qualifier, value and relative count.
  bool operator==(const DatetimeComponent& other) const {
    if (component_type != other.component_type ||
        relative_qualifier != other.relative_qualifier) {
      return false;
    }
    return value == other.value && relative_count == other.relative_count;
  }

  bool ShouldRoundToGranularity() const;

  ComponentType component_type = ComponentType::UNSPECIFIED;
  RelativeQualifier relative_qualifier = RelativeQualifier::UNSPECIFIED;

  // Absolute value of the component.
  int value = 0;
  // Number of units of change expressed by a relative component.
  int relative_count = 0;

  DatetimeComponent() = default;

  explicit DatetimeComponent(ComponentType arg_component_type,
                             RelativeQualifier arg_relative_qualifier,
                             int arg_value, int arg_relative_count)
      : component_type(arg_component_type),
        relative_qualifier(arg_relative_qualifier),
        value(arg_value),
        relative_count(arg_relative_count) {}
};
339
340 // Utility method to calculate Returns the finest granularity of
341 // DatetimeComponents.
342 DatetimeGranularity GetFinestGranularity(
343 const std::vector<DatetimeComponent>& datetime_component);
344
345 // Return the 'DatetimeComponent' from collection filter by component type.
346 Optional<DatetimeComponent> GetDatetimeComponent(
347 const std::vector<DatetimeComponent>& datetime_components,
348 const DatetimeComponent::ComponentType& component_type);
349
350 struct DatetimeParseResult {
351 // The absolute time in milliseconds since the epoch in UTC.
352 int64 time_ms_utc;
353
354 // The precision of the estimate then in to calculating the milliseconds
355 DatetimeGranularity granularity;
356
357 // List of parsed DateTimeComponent.
358 std::vector<DatetimeComponent> datetime_components;
359
DatetimeParseResultDatetimeParseResult360 DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
361
DatetimeParseResultDatetimeParseResult362 DatetimeParseResult(int64 arg_time_ms_utc,
363 DatetimeGranularity arg_granularity,
364 std::vector<DatetimeComponent> arg_datetime__components)
365 : time_ms_utc(arg_time_ms_utc),
366 granularity(arg_granularity),
367 datetime_components(arg_datetime__components) {}
368
IsSetDatetimeParseResult369 bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
370
371 bool operator==(const DatetimeParseResult& other) const {
372 return granularity == other.granularity &&
373 time_ms_utc == other.time_ms_utc &&
374 datetime_components == other.datetime_components;
375 }
376 };
377
// Absolute tolerance used when comparing float scores for equality.
constexpr float kFloatCompareEpsilon = 1e-5;
379
380 struct DatetimeParseResultSpan {
381 CodepointSpan span;
382 std::vector<DatetimeParseResult> data;
383 float target_classification_score;
384 float priority_score;
385
DatetimeParseResultSpanDatetimeParseResultSpan386 DatetimeParseResultSpan()
387 : span(CodepointSpan::kInvalid),
388 target_classification_score(-1.0),
389 priority_score(-1.0) {}
390
DatetimeParseResultSpanDatetimeParseResultSpan391 DatetimeParseResultSpan(const CodepointSpan& span,
392 const std::vector<DatetimeParseResult>& data,
393 const float target_classification_score,
394 const float priority_score)
395 : span(span),
396 data(data),
397 target_classification_score(target_classification_score),
398 priority_score(priority_score) {}
399
400 bool operator==(const DatetimeParseResultSpan& other) const {
401 return span == other.span && data == other.data &&
402 std::abs(target_classification_score -
403 other.target_classification_score) < kFloatCompareEpsilon &&
404 std::abs(priority_score - other.priority_score) <
405 kFloatCompareEpsilon;
406 }
407 };
408
409 // Pretty-printing function for DatetimeParseResultSpan.
410 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
411 const DatetimeParseResultSpan& value);
412
// Information intended to uniquely identify a device contact. Instances are
// created by the Knowledge Engine and dereferenced by the Contact Engine.
struct ContactPointer {
  std::string focus_contact_id;
  std::string device_id;
  std::string device_contact_id;
  std::string contact_name;
  std::string contact_name_hash;

  // Equal when every identifying field matches.
  bool operator==(const ContactPointer& other) const {
    if (focus_contact_id != other.focus_contact_id) return false;
    if (device_id != other.device_id) return false;
    if (device_contact_id != other.device_contact_id) return false;
    if (contact_name != other.contact_name) return false;
    return contact_name_hash == other.contact_name_hash;
  }
};
431
432 struct ClassificationResult {
433 std::string collection;
434 float score;
435 DatetimeParseResult datetime_parse_result;
436 std::string serialized_knowledge_result;
437 ContactPointer contact_pointer;
438 std::string contact_name, contact_given_name, contact_family_name,
439 contact_nickname, contact_email_address, contact_phone_number,
440 contact_account_type, contact_account_name, contact_id;
441 std::string app_name, app_package_name;
442 int64 numeric_value;
443 double numeric_double_value;
444
445 // Length of the parsed duration in milliseconds.
446 int64 duration_ms;
447
448 // Internal score used for conflict resolution.
449 float priority_score;
450
451
452 // Entity data information.
453 std::string serialized_entity_data;
entity_dataClassificationResult454 const EntityData* entity_data() const {
455 return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
456 serialized_entity_data.size());
457 }
458
ClassificationResultClassificationResult459 explicit ClassificationResult()
460 : score(-1.0f),
461 numeric_value(0),
462 numeric_double_value(0.),
463 duration_ms(0),
464 priority_score(-1.0) {}
465
ClassificationResultClassificationResult466 ClassificationResult(const std::string& arg_collection, float arg_score)
467 : collection(arg_collection),
468 score(arg_score),
469 numeric_value(0),
470 numeric_double_value(0.),
471 duration_ms(0),
472 priority_score(arg_score) {}
473
ClassificationResultClassificationResult474 ClassificationResult(const std::string& arg_collection, float arg_score,
475 float arg_priority_score)
476 : collection(arg_collection),
477 score(arg_score),
478 numeric_value(0),
479 numeric_double_value(0.),
480 duration_ms(0),
481 priority_score(arg_priority_score) {}
482
483 bool operator!=(const ClassificationResult& other) const {
484 return !(*this == other);
485 }
486
487 bool operator==(const ClassificationResult& other) const;
488 };
489
490 // Aliases for long enum values.
491 const AnnotationUsecase ANNOTATION_USECASE_SMART =
492 AnnotationUsecase_ANNOTATION_USECASE_SMART;
493 const AnnotationUsecase ANNOTATION_USECASE_RAW =
494 AnnotationUsecase_ANNOTATION_USECASE_RAW;
495
// User location information passed along with an annotation request.
struct LocationContext {
  // User location latitude in degrees. The default of 180 is outside the
  // latitude range; presumably a "no location" sentinel — confirm.
  double user_location_lat = 180.;

  // User location longitude in degrees. The default of 360 presumably also
  // acts as a sentinel — confirm.
  double user_location_lng = 360.;

  // The estimated horizontal accuracy of the user location in meters.
  // Analogous to android.location.Location accuracy.
  float user_location_accuracy_meters = 0.f;

  // Equal when every field differs by less than 1e-8.
  bool operator==(const LocationContext& other) const {
    const bool lat_close =
        std::fabs(user_location_lat - other.user_location_lat) < 1e-8;
    const bool lng_close =
        std::fabs(user_location_lng - other.user_location_lng) < 1e-8;
    const bool accuracy_close =
        std::fabs(user_location_accuracy_meters -
                  other.user_location_accuracy_meters) < 1e-8;
    return lat_close && lng_close && accuracy_close;
  }
};
516
517 struct BaseOptions {
518 // Comma-separated list of locale specification for the input text (BCP 47
519 // tags).
520 std::string locales;
521
522 // Comma-separated list of BCP 47 language tags.
523 std::string detected_text_language_tags;
524
525 // Tailors the output annotations according to the specified use-case.
526 AnnotationUsecase annotation_usecase = ANNOTATION_USECASE_SMART;
527
528 // The location context passed along with each annotation.
529 Optional<LocationContext> location_context;
530
531 // If true, the POD NER annotator is used.
532 bool use_pod_ner = true;
533
534 // If true and the model file supports that, the new vocab annotator is used
535 // to annotate "Dictionary". Otherwise, we use the FFModel to do so.
536 bool use_vocab_annotator = true;
537
538 bool operator==(const BaseOptions& other) const {
539 bool location_context_equality = this->location_context.has_value() ==
540 other.location_context.has_value();
541 if (this->location_context.has_value() &&
542 other.location_context.has_value()) {
543 location_context_equality =
544 this->location_context.value() == other.location_context.value();
545 }
546 return this->locales == other.locales &&
547 this->annotation_usecase == other.annotation_usecase &&
548 this->detected_text_language_tags ==
549 other.detected_text_language_tags &&
550 location_context_equality &&
551 this->use_pod_ner == other.use_pod_ner &&
552 this->use_vocab_annotator == other.use_vocab_annotator;
553 }
554 };
555
556 struct DatetimeOptions {
557 // For parsing relative datetimes, the reference now time against which the
558 // relative datetimes get resolved.
559 // UTC milliseconds since epoch.
560 int64 reference_time_ms_utc = 0;
561
562 // Timezone in which the input text was written (format as accepted by ICU).
563 std::string reference_timezone;
564
565 bool operator==(const DatetimeOptions& other) const {
566 return this->reference_time_ms_utc == other.reference_time_ms_utc &&
567 this->reference_timezone == other.reference_timezone;
568 }
569 };
570
571 struct SelectionOptions : public BaseOptions {};
572
573 struct ClassificationOptions : public BaseOptions, public DatetimeOptions {
574 // Comma-separated list of language tags which the user can read and
575 // understand (BCP 47).
576 std::string user_familiar_language_tags;
577 // If true, trigger dictionary on words that are of beginner level.
578 bool trigger_dictionary_on_beginner_words = false;
579
580 bool operator==(const ClassificationOptions& other) const {
581 return this->user_familiar_language_tags ==
582 other.user_familiar_language_tags &&
583 this->trigger_dictionary_on_beginner_words ==
584 other.trigger_dictionary_on_beginner_words &&
585 BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
586 }
587 };
588
// What the annotators are allowed to use. Everything is permitted by default.
struct Permissions {
  // If true the user location can be used to provide better annotations.
  bool has_location_permission = true;
  // If true, annotators can use personal data to provide personalized
  // annotations.
  bool has_personalization_permission = true;

  bool operator==(const Permissions& other) const {
    const bool same_location =
        has_location_permission == other.has_location_permission;
    const bool same_personalization =
        has_personalization_permission == other.has_personalization_permission;
    return same_location && same_personalization;
  }
};
602
603 struct AnnotationOptions : public BaseOptions, public DatetimeOptions {
604 // List of entity types that should be used for annotation.
605 std::unordered_set<std::string> entity_types;
606
607 // If true, serialized_entity_data in the results is populated."
608 bool is_serialized_entity_data_enabled = false;
609
610 // Defines the permissions for the annotators.
611 Permissions permissions;
612
613 AnnotateMode annotate_mode = AnnotateMode::kEntityAnnotation;
614
615 // If true, trigger dictionary on words that are of beginner level.
616 bool trigger_dictionary_on_beginner_words = false;
617
618 // If true, enables an optimized code path for annotation.
619 // The optimization caused crashes previously, which is why we are rolling it
620 // out using this temporary flag. See: b/178503899
621 bool enable_optimization = false;
622
623 bool operator==(const AnnotationOptions& other) const {
624 return this->is_serialized_entity_data_enabled ==
625 other.is_serialized_entity_data_enabled &&
626 this->permissions == other.permissions &&
627 this->entity_types == other.entity_types &&
628 this->annotate_mode == other.annotate_mode &&
629 this->trigger_dictionary_on_beginner_words ==
630 other.trigger_dictionary_on_beginner_words &&
631 BaseOptions::operator==(other) && DatetimeOptions::operator==(other);
632 }
633 };
634
635 // Returns true when ClassificationResults are euqal up to scores.
636 bool ClassificationResultsEqualIgnoringScoresAndSerializedEntityData(
637 const ClassificationResult& a, const ClassificationResult& b);
638
639 // Pretty-printing function for ClassificationResult.
640 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
641 const ClassificationResult& result);
642
643 // Pretty-printing function for std::vector<ClassificationResult>.
644 logging::LoggingStringStream& operator<<(
645 logging::LoggingStringStream& stream,
646 const std::vector<ClassificationResult>& results);
647
648 // Represents a result of Annotate call.
649 struct AnnotatedSpan {
650 enum class Source { OTHER, KNOWLEDGE, DURATION, DATETIME, PERSON_NAME };
651
652 // Unicode codepoint indices in the input string.
653 CodepointSpan span = CodepointSpan::kInvalid;
654
655 // Classification result for the span.
656 std::vector<ClassificationResult> classification;
657
658 // The source of the annotation, used in conflict resolution.
659 Source source = Source::OTHER;
660
661 AnnotatedSpan() = default;
662
AnnotatedSpanAnnotatedSpan663 AnnotatedSpan(CodepointSpan arg_span,
664 std::vector<ClassificationResult> arg_classification)
665 : span(arg_span), classification(std::move(arg_classification)) {}
666
AnnotatedSpanAnnotatedSpan667 AnnotatedSpan(CodepointSpan arg_span,
668 std::vector<ClassificationResult> arg_classification,
669 Source arg_source)
670 : span(arg_span),
671 classification(std::move(arg_classification)),
672 source(arg_source) {}
673 };
674
675 // Represents Annotations that correspond to all input fragments.
676 struct Annotations {
677 // List of annotations found in the corresponding input fragments. For these
678 // annotations, topicality score will not be set.
679 std::vector<std::vector<AnnotatedSpan>> annotated_spans;
680
681 // List of topicality results found across all input fragments.
682 std::vector<ClassificationResult> topicality_results;
683
684 Annotations() = default;
685
AnnotationsAnnotations686 explicit Annotations(
687 std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans)
688 : annotated_spans(std::move(arg_annotated_spans)) {}
689
AnnotationsAnnotations690 Annotations(std::vector<std::vector<AnnotatedSpan>> arg_annotated_spans,
691 std::vector<ClassificationResult> arg_topicality_results)
692 : annotated_spans(std::move(arg_annotated_spans)),
693 topicality_results(std::move(arg_topicality_results)) {}
694 };
695
696 struct InputFragment {
697 std::string text;
698 float bounding_box_top;
699 float bounding_box_height;
700
701 // If present will override the AnnotationOptions reference time and timezone
702 // when annotating this specific string fragment.
703 Optional<DatetimeOptions> datetime_options;
704 };
705
706 // Pretty-printing function for AnnotatedSpan.
707 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
708 const AnnotatedSpan& span);
709
// StringPiece analogue for std::vector<T>: a non-owning, read-only view of a
// contiguous iterator range of a std::vector<T>. The view stores iterators, so
// the underlying vector must outlive it and must not reallocate while it is in
// use.
template <class T>
class VectorSpan {
 public:
  // Creates an empty span.
  VectorSpan() : begin_(), end_() {}

  // Creates a span over all elements of `v`.
  explicit VectorSpan(const std::vector<T>& v)
      : begin_(v.begin()), end_(v.end()) {}

  // Creates a span over [begin, end).
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  // Returns the i-th element; `i` must be less than size(). Not bounds-checked.
  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  // Number of elements in the span.
  int size() const { return end_ - begin_; }
  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Pointer to the first element, or nullptr for an empty span.
  // The return type used to be `const float*` regardless of T, so it only
  // compiled for T = float. It also dereferenced begin_ on an empty span,
  // which is undefined behavior.
  const T* data() const { return begin_ == end_ ? nullptr : &*begin_; }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
734
735 // Class to provide representation of date and time expressions
736 class DatetimeParsedData {
737 public:
738 // Function to set the absolute value of DateTimeComponent for the given
739 // FieldType, if the field is not present it will create the field and set
740 // the value.
741 void SetAbsoluteValue(const DatetimeComponent::ComponentType& field_type,
742 int value);
743
744 // Function to set the relative value of DateTimeComponent, if the field is
745 // not present the function will create the field and set the relative value.
746 void SetRelativeValue(
747 const DatetimeComponent::ComponentType& field_type,
748 const DatetimeComponent::RelativeQualifier& relative_value);
749
750 // Add collection of 'DatetimeComponent' to 'DatetimeParsedData'.
751 void AddDatetimeComponents(
752 const std::vector<DatetimeComponent>& datetime_components);
753
754 // Function to set the relative count of DateTimeComponent, if the field is
755 // not present the function will create the field and set the count.
756 void SetRelativeCount(const DatetimeComponent::ComponentType& field_type,
757 int relative_count);
758
759 // Function to populate the absolute value of the FieldType and return true.
760 // In case of no FieldType function will return false.
761 bool GetFieldValue(const DatetimeComponent::ComponentType& field_type,
762 int* field_value) const;
763
764 // Function to populate the relative value of the FieldType and return true.
765 // In case of no relative value function will return false.
766 bool GetRelativeValue(
767 const DatetimeComponent::ComponentType& field_type,
768 DatetimeComponent::RelativeQualifier* relative_value) const;
769
770 // Returns relative DateTimeComponent from the parsed DateTime span.
771 void GetRelativeDatetimeComponents(
772 std::vector<DatetimeComponent>* date_time_components) const;
773
774 // Returns DateTimeComponent from the parsed DateTime span.
775 void GetDatetimeComponents(
776 std::vector<DatetimeComponent>* date_time_components) const;
777
778 // Represent the granularity of the Parsed DateTime span. The function will
779 // return “GRANULARITY_UNKNOWN” if no datetime field is set.
780 DatetimeGranularity GetFinestGranularity() const;
781
782 // Utility function to check if DateTimeParsedData has FieldType initialized.
783 bool HasFieldType(const DatetimeComponent::ComponentType& field_type) const;
784
785 // Function to check if DateTimeParsedData has relative DateTimeComponent for
786 // given FieldType.
787 bool HasRelativeValue(
788 const DatetimeComponent::ComponentType& field_type) const;
789
790 // Function to check if DateTimeParsedData has absolute value
791 // DateTimeComponent for given FieldType.
792 bool HasAbsoluteValue(
793 const DatetimeComponent::ComponentType& field_type) const;
794
795 // Function to check if DateTimeParsedData has any DateTimeComponent.
796 bool IsEmpty() const;
797
798 private:
799 DatetimeComponent& GetOrCreateDatetimeComponent(
800
801 const DatetimeComponent::ComponentType& component_type);
802
803 std::map<DatetimeComponent::ComponentType, DatetimeComponent>
804 date_time_components_;
805 };
806
807 // Pretty-printing function for DateTimeParsedData.
808 logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
809 const DatetimeParsedData& data);
810
811 } // namespace libtextclassifier3
812
813 #endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_
814