1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/lua-actions.h"
18
19 #include <map>
20 #include <string>
21
22 #include "actions/test-utils.h"
23 #include "actions/types.h"
24 #include "utils/tflite-model-executor.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27
28 namespace libtextclassifier3 {
29 namespace {
30
31 using testing::ElementsAre;
32
// A snippet that unconditionally returns a single action table should yield
// exactly one suggestion of that type for a default-constructed conversation.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";

  // Check the factory result before using it: dereferencing a null result
  // would crash the whole test binary instead of failing only this test.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_NE(lua_actions, nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
49
// The conversation's messages are visible to the snippet as `messages`. Here
// the snippet emits a text reply when "hello there!" is directly followed by
// "general kenobi!". The second rule ("i am the senate!" / "not yet!") must not
// fire, because the conversation does not contain that exchange.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
    local actions = {}
    for i, message in pairs(messages) do
      if i < #messages then
        if message.text == "hello there!" and
           messages[i+1].text == "general kenobi!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "you are a bold one!"
           })
        end
        if message.text == "i am the senate!" and
           messages[i+1].text == "not yet!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "it's treason then"
           })
        end
      end
    end
    return actions;
  )";

  // Check the factory result before using it: dereferencing a null result
  // would crash the whole test binary instead of failing only this test.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_NE(lua_actions, nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
}
87
88 TEST(LuaActions, SimpleModelAction) {
89 Conversation conversation;
90 const std::string test_snippet = R"(
91 if #model.actions_scores == 0 then
92 return {{ type = "test_action" }}
93 end
94 return {}
95 )";
96 std::vector<ActionSuggestion> actions;
97 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
98 test_snippet, conversation,
99 /*model_executor=*/nullptr,
100 /*model_spec=*/nullptr,
101 /*interpreter=*/nullptr,
102 /*actions_entity_data_schema=*/nullptr,
103 /*annotations_entity_data_schema=*/nullptr)
104 ->SuggestActions(&actions));
105 EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
106 }
107
108 TEST(LuaActions, SimpleModelRepliesAction) {
109 Conversation conversation;
110 const std::string test_snippet = R"(
111 if #model.reply == 0 then
112 return {{ type = "test_action" }}
113 end
114 return {}
115 )";
116 std::vector<ActionSuggestion> actions;
117 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
118 test_snippet, conversation,
119 /*model_executor=*/nullptr,
120 /*model_spec=*/nullptr,
121 /*interpreter=*/nullptr,
122 /*actions_entity_data_schema=*/nullptr,
123 /*annotations_entity_data_schema=*/nullptr)
124 ->SuggestActions(&actions));
125 EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
126 }
127
128 TEST(LuaActions, AnnotationActions) {
129 AnnotatedSpan annotation;
130 annotation.span = {11, 15};
131 annotation.classification = {ClassificationResult("address", 1.0)};
132 Conversation conversation = {{{/*user_id=*/1, "are you at home?",
133 /*reference_time_ms_utc=*/0,
134 /*reference_timezone=*/"Europe/Zurich",
135 /*annotations=*/{annotation},
136 /*locales=*/"en"}}};
137 const std::string test_snippet = R"(
138 local actions = {}
139 local last_message = messages[#messages]
140 for i, annotation in pairs(last_message.annotation) do
141 if #annotation.classification > 0 then
142 if annotation.classification[1].collection == "address" then
143 local text = string.sub(last_message.text,
144 annotation.span["begin"] + 1,
145 annotation.span["end"])
146 table.insert(actions, {
147 type = "text_reply",
148 response_text = "i am at " .. text,
149 annotation = {{
150 name = "location",
151 span = {
152 text = text
153 },
154 entity = annotation.classification[1]
155 }},
156 })
157 end
158 end
159 end
160 return actions;
161 )";
162 std::vector<ActionSuggestion> actions;
163 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
164 test_snippet, conversation,
165 /*model_executor=*/nullptr,
166 /*model_spec=*/nullptr,
167 /*interpreter=*/nullptr,
168 /*actions_entity_data_schema=*/nullptr,
169 /*annotations_entity_data_schema=*/nullptr)
170 ->SuggestActions(&actions));
171 EXPECT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
172 EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
173 }
174
175 TEST(LuaActions, EntityData) {
176 std::string test_schema = TestEntityDataSchema();
177 Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
178 const std::string test_snippet = R"(
179 return {{
180 type = "test",
181 entity = {
182 greeting = "hello",
183 location = "there",
184 person = "Kenobi",
185 },
186 }};
187 )";
188 std::vector<ActionSuggestion> actions;
189 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
190 test_snippet, conversation,
191 /*model_executor=*/nullptr,
192 /*model_spec=*/nullptr,
193 /*interpreter=*/nullptr,
194 /*actions_entity_data_schema=*/
195 flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
196 /*annotations_entity_data_schema=*/nullptr)
197 ->SuggestActions(&actions));
198 EXPECT_THAT(actions, testing::SizeIs(1));
199 EXPECT_EQ("test", actions.front().type);
200 const flatbuffers::Table* entity =
201 flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
202 actions.front().serialized_entity_data.data()));
203 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
204 "hello");
205 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
206 "there");
207 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
208 "Kenobi");
209 }
210
211 } // namespace
212 } // namespace libtextclassifier3
213