1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/lua-actions.h"
18
19 #include <map>
20 #include <string>
21
22 #include "actions/test_utils.h"
23 #include "actions/types.h"
24 #include "utils/tflite-model-executor.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27
28 namespace libtextclassifier3 {
29 namespace {
30
31 MATCHER_P2(IsAction, type, response_text, "") {
32 return testing::Value(arg.type, type) &&
33 testing::Value(arg.response_text, response_text);
34 }
35
36 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
37
// A snippet that returns a constant table yields exactly that one action.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // Fail the test (rather than crash the binary) if creation returned null,
  // e.g. presumably when the snippet cannot be loaded.
  ASSERT_TRUE(lua_actions != nullptr);
  std::vector<ActionSuggestion> actions;
  ASSERT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions,
              testing::ElementsAreArray({IsActionType("test_action")}));
}
55
TEST(LuaActions,ConversationActions)56 TEST(LuaActions, ConversationActions) {
57 Conversation conversation;
58 conversation.messages.push_back({/*user_id=*/0, "hello there!"});
59 conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
60 const std::string test_snippet = R"(
61 local actions = {}
62 for i, message in pairs(messages) do
63 if i < #messages then
64 if message.text == "hello there!" and
65 messages[i+1].text == "general kenobi!" then
66 table.insert(actions, {
67 type = "text_reply",
68 response_text = "you are a bold one!"
69 })
70 end
71 if message.text == "i am the senate!" and
72 messages[i+1].text == "not yet!" then
73 table.insert(actions, {
74 type = "text_reply",
75 response_text = "it's treason then"
76 })
77 end
78 end
79 end
80 return actions;
81 )";
82 std::vector<ActionSuggestion> actions;
83 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
84 test_snippet, conversation,
85 /*model_executor=*/nullptr,
86 /*model_spec=*/nullptr,
87 /*interpreter=*/nullptr,
88 /*actions_entity_data_schema=*/nullptr,
89 /*annotations_entity_data_schema=*/nullptr)
90 ->SuggestActions(&actions));
91 EXPECT_THAT(actions, testing::ElementsAreArray(
92 {IsAction("text_reply", "you are a bold one!")}));
93 }
94
95 TEST(LuaActions, SimpleModelAction) {
96 Conversation conversation;
97 const std::string test_snippet = R"(
98 if #model.actions_scores == 0 then
99 return {{ type = "test_action" }}
100 end
101 return {}
102 )";
103 std::vector<ActionSuggestion> actions;
104 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
105 test_snippet, conversation,
106 /*model_executor=*/nullptr,
107 /*model_spec=*/nullptr,
108 /*interpreter=*/nullptr,
109 /*actions_entity_data_schema=*/nullptr,
110 /*annotations_entity_data_schema=*/nullptr)
111 ->SuggestActions(&actions));
112 EXPECT_THAT(actions,
113 testing::ElementsAreArray({IsActionType("test_action")}));
114 }
115
116 TEST(LuaActions, AnnotationActions) {
117 AnnotatedSpan annotation;
118 annotation.span = {11, 15};
119 annotation.classification = {ClassificationResult("address", 1.0)};
120 Conversation conversation = {{{/*user_id=*/1, "are you at home?",
121 /*reference_time_ms_utc=*/0,
122 /*reference_timezone=*/"Europe/Zurich",
123 /*annotations=*/{annotation},
124 /*locales=*/"en"}}};
125 const std::string test_snippet = R"(
126 local actions = {}
127 local last_message = messages[#messages]
128 for i, annotation in pairs(last_message.annotation) do
129 if #annotation.classification > 0 then
130 if annotation.classification[1].collection == "address" then
131 local text = string.sub(last_message.text,
132 annotation.span["begin"] + 1,
133 annotation.span["end"])
134 table.insert(actions, {
135 type = "text_reply",
136 response_text = "i am at " .. text,
137 annotation = {{
138 name = "location",
139 span = {
140 text = text
141 },
142 entity = annotation.classification[1]
143 }},
144 })
145 end
146 end
147 end
148 return actions;
149 )";
150 std::vector<ActionSuggestion> actions;
151 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
152 test_snippet, conversation,
153 /*model_executor=*/nullptr,
154 /*model_spec=*/nullptr,
155 /*interpreter=*/nullptr,
156 /*actions_entity_data_schema=*/nullptr,
157 /*annotations_entity_data_schema=*/nullptr)
158 ->SuggestActions(&actions));
159 EXPECT_THAT(actions, testing::ElementsAreArray(
160 {IsAction("text_reply", "i am at home")}));
161 EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
162 }
163
164 TEST(LuaActions, EntityData) {
165 std::string test_schema = TestEntityDataSchema();
166 Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
167 const std::string test_snippet = R"(
168 return {{
169 type = "test",
170 entity = {
171 greeting = "hello",
172 location = "there",
173 person = "Kenobi",
174 },
175 }};
176 )";
177 std::vector<ActionSuggestion> actions;
178 EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
179 test_snippet, conversation,
180 /*model_executor=*/nullptr,
181 /*model_spec=*/nullptr,
182 /*interpreter=*/nullptr,
183 /*actions_entity_data_schema=*/
184 flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
185 /*annotations_entity_data_schema=*/nullptr)
186 ->SuggestActions(&actions));
187 EXPECT_THAT(actions, testing::SizeIs(1));
188 EXPECT_EQ("test", actions.front().type);
189 const flatbuffers::Table* entity =
190 flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
191 actions.front().serialized_entity_data.data()));
192 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
193 "hello");
194 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
195 "there");
196 EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
197 "Kenobi");
198 }
199
200 } // namespace
201 } // namespace libtextclassifier3
202