1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/datetime/testing/base-parser-test.h"
18
19 #include <memory>
20 #include <string>
21 #include <vector>
22
23 #include "utils/i18n/locale-list.h"
24 #include "gmock/gmock.h"
25 #include "gtest/gtest.h"
26
27 using std::vector;
28 using testing::ElementsAreArray;
29
30 namespace libtextclassifier3 {
31
HasNoResult(const std::string & text,bool anchor_start_end,const std::string & timezone,AnnotationUsecase annotation_usecase)32 bool DateTimeParserTest::HasNoResult(const std::string& text,
33 bool anchor_start_end,
34 const std::string& timezone,
35 AnnotationUsecase annotation_usecase) {
36 StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
37 DatetimeParserForTests()->Parse(
38 text, 0, timezone, LocaleList::ParseFrom(/*locale_tags=*/""),
39 ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
40 if (!results_status.ok()) {
41 TC3_LOG(ERROR) << text;
42 TC3_CHECK(false);
43 }
44 return results_status.ValueOrDie().empty();
45 }
46
ParsesCorrectly(const std::string & marked_text,const vector<int64> & expected_ms_utcs,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components,bool anchor_start_end,const std::string & timezone,const std::string & locales,AnnotationUsecase annotation_usecase)47 bool DateTimeParserTest::ParsesCorrectly(
48 const std::string& marked_text, const vector<int64>& expected_ms_utcs,
49 DatetimeGranularity expected_granularity,
50 vector<vector<DatetimeComponent>> datetime_components,
51 bool anchor_start_end, const std::string& timezone,
52 const std::string& locales, AnnotationUsecase annotation_usecase) {
53 const UnicodeText marked_text_unicode =
54 UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
55 auto brace_open_it =
56 std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
57 auto brace_end_it =
58 std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
59 TC3_CHECK(brace_open_it != marked_text_unicode.end());
60 TC3_CHECK(brace_end_it != marked_text_unicode.end());
61
62 std::string text;
63 text +=
64 UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
65 text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
66 text += UnicodeText::UTF8Substring(std::next(brace_end_it),
67 marked_text_unicode.end());
68
69 StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
70 DatetimeParserForTests()->Parse(
71 text, 0, timezone, LocaleList::ParseFrom(locales),
72 ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
73 if (!results_status.ok()) {
74 TC3_LOG(ERROR) << text;
75 TC3_CHECK(false);
76 }
77 // const std::vector<DatetimeParseResultSpan>& results =
78 // results_status.ValueOrDie();
79 if (results_status.ValueOrDie().empty()) {
80 TC3_LOG(ERROR) << "No results.";
81 return false;
82 }
83
84 const int expected_start_index =
85 std::distance(marked_text_unicode.begin(), brace_open_it);
86 // The -1 below is to account for the opening bracket character.
87 const int expected_end_index =
88 std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
89
90 std::vector<DatetimeParseResultSpan> filtered_results;
91 for (const DatetimeParseResultSpan& result : results_status.ValueOrDie()) {
92 if (SpansOverlap(result.span, {expected_start_index, expected_end_index})) {
93 filtered_results.push_back(result);
94 }
95 }
96 std::vector<DatetimeParseResultSpan> expected{
97 {{expected_start_index, expected_end_index},
98 {},
99 /*target_classification_score=*/1.0,
100 /*priority_score=*/1.0}};
101 expected[0].data.resize(expected_ms_utcs.size());
102 for (int i = 0; i < expected_ms_utcs.size(); i++) {
103 expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
104 datetime_components[i]};
105 }
106
107 const bool matches =
108 testing::Matches(ElementsAreArray(expected))(filtered_results);
109 if (!matches) {
110 TC3_LOG(ERROR) << "Expected: " << expected[0];
111 if (filtered_results.empty()) {
112 TC3_LOG(ERROR) << "But got no results.";
113 }
114 TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
115 }
116
117 return matches;
118 }
119
ParsesCorrectly(const std::string & marked_text,const int64 expected_ms_utc,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components,bool anchor_start_end,const std::string & timezone,const std::string & locales,AnnotationUsecase annotation_usecase)120 bool DateTimeParserTest::ParsesCorrectly(
121 const std::string& marked_text, const int64 expected_ms_utc,
122 DatetimeGranularity expected_granularity,
123 vector<vector<DatetimeComponent>> datetime_components,
124 bool anchor_start_end, const std::string& timezone,
125 const std::string& locales, AnnotationUsecase annotation_usecase) {
126 return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
127 expected_granularity, datetime_components,
128 anchor_start_end, timezone, locales,
129 annotation_usecase);
130 }
131
ParsesCorrectlyGerman(const std::string & marked_text,const vector<int64> & expected_ms_utcs,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components)132 bool DateTimeParserTest::ParsesCorrectlyGerman(
133 const std::string& marked_text, const vector<int64>& expected_ms_utcs,
134 DatetimeGranularity expected_granularity,
135 vector<vector<DatetimeComponent>> datetime_components) {
136 return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
137 datetime_components,
138 /*anchor_start_end=*/false,
139 /*timezone=*/"Europe/Zurich", /*locales=*/"de");
140 }
141
ParsesCorrectlyGerman(const std::string & marked_text,const int64 expected_ms_utc,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components)142 bool DateTimeParserTest::ParsesCorrectlyGerman(
143 const std::string& marked_text, const int64 expected_ms_utc,
144 DatetimeGranularity expected_granularity,
145 vector<vector<DatetimeComponent>> datetime_components) {
146 return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
147 datetime_components,
148 /*anchor_start_end=*/false,
149 /*timezone=*/"Europe/Zurich", /*locales=*/"de");
150 }
151
ParsesCorrectlyChinese(const std::string & marked_text,const int64 expected_ms_utc,DatetimeGranularity expected_granularity,vector<vector<DatetimeComponent>> datetime_components)152 bool DateTimeParserTest::ParsesCorrectlyChinese(
153 const std::string& marked_text, const int64 expected_ms_utc,
154 DatetimeGranularity expected_granularity,
155 vector<vector<DatetimeComponent>> datetime_components) {
156 return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
157 datetime_components,
158 /*anchor_start_end=*/false,
159 /*timezone=*/"Europe/Zurich", /*locales=*/"zh");
160 }
161
162 } // namespace libtextclassifier3
163