• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/duration/duration.h"
18 
19 #include <string>
20 #include <vector>
21 
22 #include "annotator/collections.h"
23 #include "annotator/model_generated.h"
24 #include "annotator/types-test-util.h"
25 #include "annotator/types.h"
26 #include "utils/tokenizer-utils.h"
27 #include "utils/utf8/unicodetext.h"
28 #include "utils/utf8/unilib.h"
29 #include "gmock/gmock.h"
30 #include "gtest/gtest.h"
31 
32 namespace libtextclassifier3 {
33 namespace {
34 
35 using testing::AllOf;
36 using testing::ElementsAre;
37 using testing::Field;
38 using testing::IsEmpty;
39 
TestingDurationAnnotatorOptions()40 const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
41   static const flatbuffers::DetachedBuffer* options_data = []() {
42     DurationAnnotatorOptionsT options;
43     options.enabled = true;
44 
45     options.week_expressions.push_back("week");
46     options.week_expressions.push_back("weeks");
47 
48     options.day_expressions.push_back("day");
49     options.day_expressions.push_back("days");
50 
51     options.hour_expressions.push_back("hour");
52     options.hour_expressions.push_back("hours");
53 
54     options.minute_expressions.push_back("minute");
55     options.minute_expressions.push_back("minutes");
56 
57     options.second_expressions.push_back("second");
58     options.second_expressions.push_back("seconds");
59 
60     options.filler_expressions.push_back("and");
61     options.filler_expressions.push_back("a");
62     options.filler_expressions.push_back("an");
63     options.filler_expressions.push_back("one");
64 
65     options.half_expressions.push_back("half");
66 
67     options.sub_token_separator_codepoints.push_back('-');
68 
69     flatbuffers::FlatBufferBuilder builder;
70     builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
71     return new flatbuffers::DetachedBuffer(builder.Release());
72   }();
73 
74   return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
75 }
76 
BuildFeatureProcessor(const UniLib * unilib)77 std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
78   static const flatbuffers::DetachedBuffer* options_data = []() {
79     FeatureProcessorOptionsT options;
80     options.context_size = 1;
81     options.max_selection_span = 1;
82     options.snap_label_span_boundaries_to_containing_tokens = false;
83     options.ignored_span_boundary_codepoints.push_back(',');
84 
85     options.tokenization_codepoint_config.emplace_back(
86         new TokenizationCodepointRangeT());
87     auto& config = options.tokenization_codepoint_config.back();
88     config->start = 32;
89     config->end = 33;
90     config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
91 
92     flatbuffers::FlatBufferBuilder builder;
93     builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
94     return new flatbuffers::DetachedBuffer(builder.Release());
95   }();
96 
97   const FeatureProcessorOptions* feature_processor_options =
98       flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
99 
100   return std::unique_ptr<FeatureProcessor>(
101       new FeatureProcessor(feature_processor_options, unilib));
102 }
103 
104 class DurationAnnotatorTest : public ::testing::Test {
105  protected:
DurationAnnotatorTest()106   DurationAnnotatorTest()
107       : INIT_UNILIB_FOR_TESTING(unilib_),
108         feature_processor_(BuildFeatureProcessor(&unilib_)),
109         duration_annotator_(TestingDurationAnnotatorOptions(),
110                             feature_processor_.get(), &unilib_) {}
111 
Tokenize(const UnicodeText & text)112   std::vector<Token> Tokenize(const UnicodeText& text) {
113     return feature_processor_->Tokenize(text);
114   }
115 
116   UniLib unilib_;
117   std::unique_ptr<FeatureProcessor> feature_processor_;
118   DurationAnnotator duration_annotator_;
119 };
120 
TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
  ClassificationResult classification;
  // Stop the test if classification is rejected: inspecting `classification`
  // afterwards would only add a second, misleading failure.
  ASSERT_TRUE(duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "duration"),
                    Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
}
131 
TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
  ClassificationResult classification;
  // The selection [13, 23) is "15 minutes", which starts inside the token
  // "in15" and ends inside "minutesok?". Stop early if it is rejected, since
  // the field checks below would then be meaningless.
  ASSERT_TRUE(duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("Wake me up in15 minutesok?"), {13, 23},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "duration"),
                    Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
}
142 
TEST_F(DurationAnnotatorTest, DoNotClassifyWhenInputIsInvalid) {
  // The selection [5, 6) covers only the space between the two words.
  const UnicodeText input = UTF8ToUnicodeText("Weird space");
  ClassificationResult classification;
  const bool classified = duration_annotator_.ClassifyText(
      input, {5, 6}, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      &classification);
  EXPECT_FALSE(classified);
}
149 
TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // Exactly "15 minutes" should be annotated.
  const int expected_ms = 15 * 60 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
167 
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "3 and half minutes" is 3.5 minutes.
  const int expected_ms = 3 * 60 * 1000 + 30 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
186 
TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
  const UnicodeText input =
      UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "3 hours and 5 seconds" is annotated as one combined duration.
  const int expected_ms = 3 * 60 * 60 * 1000 + 5 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 35)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
205 
TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
  const UnicodeText input = UTF8ToUnicodeText(
      "See you in a week and a day and an hour and a minute and a second");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // One of each unit, summed into a single duration.
  const int second_ms = 1000;
  const int minute_ms = 60 * second_ms;
  const int hour_ms = 60 * minute_ms;
  const int day_ms = 24 * hour_ms;
  const int week_ms = 7 * day_ms;
  const int expected_ms = week_ms + day_ms + hour_ms + minute_ms + second_ms;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  // The span starts at "week": the leading filler "a" is not included.
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(13, 65)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
225 
TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
  const UnicodeText input = UTF8ToUnicodeText("Set a timer for half an hour");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // Half an hour is 30 minutes.
  const int expected_ms = 30 * 60 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 28)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
243 
TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 1 hour and a half");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // The trailing "a half" applies to the preceding unit (hours): 1.5 hours.
  const int expected_ms = 60 * 60 * 1000 + 30 * 60 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 33)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
262 
TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for an hour and a half");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // 1.5 hours. The span starts at "hour", so the leading "an" is excluded.
  const int expected_ms = 60 * 60 * 1000 + 30 * 60 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(19, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
281 
TEST_F(DurationAnnotatorTest,
       FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "a second" contributes exactly one second.
  const int expected_ms = 10 * 60 * 1000 + 1 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 39)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
301 
TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
  const UnicodeText input = UTF8ToUnicodeText(
      "Set a timer for a a a 10 minutes and 2 seconds an and an ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // The fillers before "10" and after "seconds" stay outside the span.
  const int expected_ms = 10 * 60 * 1000 + 2 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(22, 46)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
320 
TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
  const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?");
  std::vector<Token> tokens = Tokenize(text);
  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(duration_annotator_.FindAll(
      text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));

  // A lone "half" without a unit must yield no annotation. IsEmpty() avoids
  // comparing size_t against a signed int literal (sign-compare warning) and
  // prints any unexpected spans on failure.
  EXPECT_THAT(result, IsEmpty());
}
330 
TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // The span ends right after "seconds", excluding the trailing comma.
  const int expected_ms = 10 * 60 * 1000 + 2 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 46)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
349 
TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
  const UnicodeText input = UTF8ToUnicodeText("Show 5-minute timer.");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "5-minute" is a single whitespace token holding both quantity and unit.
  const int expected_ms = 5 * 60 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(5, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
367 
TEST_F(DurationAnnotatorTest,
       DoesNotIntOverflowWithDurationThatHasMoreThanInt32Millis) {
  ClassificationResult classification;
  // 1400 hours is 5,040,000,000 ms, which does not fit in 32 bits. Stop early
  // if classification is rejected, since the field checks would be meaningless.
  ASSERT_TRUE(duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("1400 hours"), {0, 10},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "duration"),
                    Field(&ClassificationResult::duration_ms,
                          1400LL * 60LL * 60LL * 1000LL)));
}
380 
TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 MiNuTeS ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // The mixed-case unit "MiNuTeS" must match the "minutes" expression.
  const int expected_ms = 15 * 60 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
398 
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 and HaLf minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "HaLf" must match the "half" expression: 3.5 minutes.
  const int expected_ms = 3 * 60 * 1000 + 30 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
417 
TEST_F(DurationAnnotatorTest,
       FindsDurationWithHalfExpressionIgnoringFillerWordCase) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 AnD half minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // The mixed-case filler "AnD" must match "and": 3.5 minutes.
  const int expected_ms = 3 * 60 * 1000 + 30 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
437 
TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("20 minutes 10");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // The trailing bare "10" is read as seconds, the unit below minutes.
  const int expected_ms = 20 * 60 * 1000 + 10 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
455 
TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) {
  const UnicodeText input = UTF8ToUnicodeText("20 seconds 10");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // After seconds the trailing "10" stays outside the annotated span.
  const int expected_ms = 20 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 10)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
472 
TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("in 10.2 hours");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // 10.2 hours = 10 hours + 12 minutes.
  const int expected_ms = 10 * 60 * 60 * 1000 + 12 * 60 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(3, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
490 
TestingJapaneseDurationAnnotatorOptions()491 const DurationAnnotatorOptions* TestingJapaneseDurationAnnotatorOptions() {
492   static const flatbuffers::DetachedBuffer* options_data = []() {
493     DurationAnnotatorOptionsT options;
494     options.enabled = true;
495 
496     options.week_expressions.push_back("週間");
497 
498     options.day_expressions.push_back("日間");
499 
500     options.hour_expressions.push_back("時間");
501 
502     options.minute_expressions.push_back("分");
503     options.minute_expressions.push_back("分間");
504 
505     options.second_expressions.push_back("秒");
506     options.second_expressions.push_back("秒間");
507 
508     options.half_expressions.push_back("半");
509 
510     options.require_quantity = true;
511     options.enable_dangling_quantity_interpretation = true;
512 
513     flatbuffers::FlatBufferBuilder builder;
514     builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
515     return new flatbuffers::DetachedBuffer(builder.Release());
516   }();
517 
518   return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
519 }
520 
521 class JapaneseDurationAnnotatorTest : public ::testing::Test {
522  protected:
JapaneseDurationAnnotatorTest()523   JapaneseDurationAnnotatorTest()
524       : INIT_UNILIB_FOR_TESTING(unilib_),
525         feature_processor_(BuildFeatureProcessor(&unilib_)),
526         duration_annotator_(TestingJapaneseDurationAnnotatorOptions(),
527                             feature_processor_.get(), &unilib_) {}
528 
Tokenize(const UnicodeText & text)529   std::vector<Token> Tokenize(const UnicodeText& text) {
530     return feature_processor_->Tokenize(text);
531   }
532 
533   UniLib unilib_;
534   std::unique_ptr<FeatureProcessor> feature_processor_;
535   DurationAnnotator duration_annotator_;
536 };
537 
TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) {
  const UnicodeText input = UTF8ToUnicodeText("10 分 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "10 分" (10 minutes) covers codepoints [0, 4).
  const int expected_ms = 10 * 60 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 4)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
555 
TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) {
  const UnicodeText input = UTF8ToUnicodeText("2 分 半 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "2 分 半" is two and a half minutes.
  const int expected_ms = 2 * 60 * 1000 + 30 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 5)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
573 
TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) {
  // The Japanese options set require_quantity, so a bare unit ("分") is
  // expected to produce no annotation.
  const UnicodeText input = UTF8ToUnicodeText("分 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));
  EXPECT_THAT(spans, IsEmpty());
}
583 
TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("2 分 10 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);
  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // The bare "10" after "分" is read as 10 seconds.
  const int expected_ms = 2 * 60 * 1000 + 10 * 1000;
  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 6)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
601 
602 }  // namespace
603 }  // namespace libtextclassifier3
604