• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/lua-actions.h"
18 
19 #include <map>
20 #include <string>
21 
22 #include "actions/test-utils.h"
23 #include "actions/types.h"
24 #include "utils/tflite-model-executor.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 
28 namespace libtextclassifier3 {
29 namespace {
30 
31 using testing::ElementsAre;
32 
// Verifies that a snippet returning one constant action yields exactly that
// action when no model executor, model spec, interpreter or entity-data
// schemas are supplied.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";
  std::vector<ActionSuggestion> actions;
  // Keep the factory result so that a null return fails the test cleanly
  // instead of crashing the test binary on dereference.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_NE(lua_actions, nullptr);
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
49 
// Verifies that the snippet can read the conversation's `messages`. Of the
// two reply rules in the snippet, only the "hello there!" -> "general kenobi!"
// rule matches this conversation, so exactly one smart reply is expected.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
    local actions = {}
    for i, message in pairs(messages) do
      if i < #messages then
        if message.text == "hello there!" and
           messages[i+1].text == "general kenobi!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "you are a bold one!"
           })
        end
        if message.text == "i am the senate!" and
           messages[i+1].text == "not yet!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "it's treason then"
           })
        end
      end
    end
    return actions;
  )";
  std::vector<ActionSuggestion> actions;
  // Keep the factory result so that a null return fails the test cleanly
  // instead of crashing the test binary on dereference.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_NE(lua_actions, nullptr);
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
}
87 
88 TEST(LuaActions, SimpleModelAction) {
89   Conversation conversation;
90   const std::string test_snippet = R"(
91     if #model.actions_scores == 0 then
92       return {{ type = "test_action" }}
93     end
94     return {}
95   )";
96   std::vector<ActionSuggestion> actions;
97   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
98                   test_snippet, conversation,
99                   /*model_executor=*/nullptr,
100                   /*model_spec=*/nullptr,
101                   /*interpreter=*/nullptr,
102                   /*actions_entity_data_schema=*/nullptr,
103                   /*annotations_entity_data_schema=*/nullptr)
104                   ->SuggestActions(&actions));
105   EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
106 }
107 
108 TEST(LuaActions, SimpleModelRepliesAction) {
109   Conversation conversation;
110   const std::string test_snippet = R"(
111     if #model.reply == 0 then
112       return {{ type = "test_action" }}
113     end
114     return {}
115   )";
116   std::vector<ActionSuggestion> actions;
117   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
118                   test_snippet, conversation,
119                   /*model_executor=*/nullptr,
120                   /*model_spec=*/nullptr,
121                   /*interpreter=*/nullptr,
122                   /*actions_entity_data_schema=*/nullptr,
123                   /*annotations_entity_data_schema=*/nullptr)
124                   ->SuggestActions(&actions));
125   EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
126 }
127 
128 TEST(LuaActions, AnnotationActions) {
129   AnnotatedSpan annotation;
130   annotation.span = {11, 15};
131   annotation.classification = {ClassificationResult("address", 1.0)};
132   Conversation conversation = {{{/*user_id=*/1, "are you at home?",
133                                  /*reference_time_ms_utc=*/0,
134                                  /*reference_timezone=*/"Europe/Zurich",
135                                  /*annotations=*/{annotation},
136                                  /*locales=*/"en"}}};
137   const std::string test_snippet = R"(
138     local actions = {}
139     local last_message = messages[#messages]
140     for i, annotation in pairs(last_message.annotation) do
141       if #annotation.classification > 0 then
142         if annotation.classification[1].collection == "address" then
143            local text = string.sub(last_message.text,
144                             annotation.span["begin"] + 1,
145                             annotation.span["end"])
146            table.insert(actions, {
147              type = "text_reply",
148              response_text = "i am at " .. text,
149              annotation = {{
150                name = "location",
151                span = {
152                  text = text
153                },
154                entity = annotation.classification[1]
155              }},
156            })
157         end
158       end
159     end
160     return actions;
161   )";
162   std::vector<ActionSuggestion> actions;
163   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
164                   test_snippet, conversation,
165                   /*model_executor=*/nullptr,
166                   /*model_spec=*/nullptr,
167                   /*interpreter=*/nullptr,
168                   /*actions_entity_data_schema=*/nullptr,
169                   /*annotations_entity_data_schema=*/nullptr)
170                   ->SuggestActions(&actions));
171   EXPECT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
172   EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
173 }
174 
175 TEST(LuaActions, EntityData) {
176   std::string test_schema = TestEntityDataSchema();
177   Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
178   const std::string test_snippet = R"(
179     return {{
180       type = "test",
181       entity = {
182         greeting = "hello",
183         location = "there",
184         person = "Kenobi",
185       },
186     }};
187   )";
188   std::vector<ActionSuggestion> actions;
189   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
190                   test_snippet, conversation,
191                   /*model_executor=*/nullptr,
192                   /*model_spec=*/nullptr,
193                   /*interpreter=*/nullptr,
194                   /*actions_entity_data_schema=*/
195                   flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
196                   /*annotations_entity_data_schema=*/nullptr)
197                   ->SuggestActions(&actions));
198   EXPECT_THAT(actions, testing::SizeIs(1));
199   EXPECT_EQ("test", actions.front().type);
200   const flatbuffers::Table* entity =
201       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
202           actions.front().serialized_entity_data.data()));
203   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
204             "hello");
205   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
206             "there");
207   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
208             "Kenobi");
209 }
210 
211 }  // namespace
212 }  // namespace libtextclassifier3
213