• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "actions/lua-actions.h"
18 
19 #include <map>
20 #include <string>
21 
22 #include "actions/test_utils.h"
23 #include "actions/types.h"
24 #include "utils/tflite-model-executor.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 
28 namespace libtextclassifier3 {
29 namespace {
30 
31 MATCHER_P2(IsAction, type, response_text, "") {
32   return testing::Value(arg.type, type) &&
33          testing::Value(arg.response_text, response_text);
34 }
35 
36 MATCHER_P(IsActionType, type, "") { return testing::Value(arg.type, type); }
37 
// A script that returns a single literal action table is expected to yield
// exactly one suggestion of that type.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
    return {{ type = "test_action" }}
  )";
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // Fail this test cleanly rather than dereferencing a null result (which
  // would crash the whole test binary) if the factory returns null.
  ASSERT_NE(lua_actions, nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions,
              testing::ElementsAreArray({IsActionType("test_action")}));
}
55 
// The script walks consecutive message pairs exposed as `messages` and emits
// a reply only for the pair that is present in the conversation; the second
// rule ("i am the senate!" / "not yet!") must not fire.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
    local actions = {}
    for i, message in pairs(messages) do
      if i < #messages then
        if message.text == "hello there!" and
           messages[i+1].text == "general kenobi!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "you are a bold one!"
           })
        end
        if message.text == "i am the senate!" and
           messages[i+1].text == "not yet!" then
           table.insert(actions, {
             type = "text_reply",
             response_text = "it's treason then"
           })
        end
      end
    end
    return actions;
  )";
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // Avoid a null dereference (and a crashed test binary) if creation fails.
  ASSERT_NE(lua_actions, nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, testing::ElementsAreArray(
                           {IsAction("text_reply", "you are a bold one!")}));
}
94 
95 TEST(LuaActions, SimpleModelAction) {
96   Conversation conversation;
97   const std::string test_snippet = R"(
98     if #model.actions_scores == 0 then
99       return {{ type = "test_action" }}
100     end
101     return {}
102   )";
103   std::vector<ActionSuggestion> actions;
104   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
105                   test_snippet, conversation,
106                   /*model_executor=*/nullptr,
107                   /*model_spec=*/nullptr,
108                   /*interpreter=*/nullptr,
109                   /*actions_entity_data_schema=*/nullptr,
110                   /*annotations_entity_data_schema=*/nullptr)
111                   ->SuggestActions(&actions));
112   EXPECT_THAT(actions,
113               testing::ElementsAreArray({IsActionType("test_action")}));
114 }
115 
116 TEST(LuaActions, AnnotationActions) {
117   AnnotatedSpan annotation;
118   annotation.span = {11, 15};
119   annotation.classification = {ClassificationResult("address", 1.0)};
120   Conversation conversation = {{{/*user_id=*/1, "are you at home?",
121                                  /*reference_time_ms_utc=*/0,
122                                  /*reference_timezone=*/"Europe/Zurich",
123                                  /*annotations=*/{annotation},
124                                  /*locales=*/"en"}}};
125   const std::string test_snippet = R"(
126     local actions = {}
127     local last_message = messages[#messages]
128     for i, annotation in pairs(last_message.annotation) do
129       if #annotation.classification > 0 then
130         if annotation.classification[1].collection == "address" then
131            local text = string.sub(last_message.text,
132                             annotation.span["begin"] + 1,
133                             annotation.span["end"])
134            table.insert(actions, {
135              type = "text_reply",
136              response_text = "i am at " .. text,
137              annotation = {{
138                name = "location",
139                span = {
140                  text = text
141                },
142                entity = annotation.classification[1]
143              }},
144            })
145         end
146       end
147     end
148     return actions;
149   )";
150   std::vector<ActionSuggestion> actions;
151   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
152                   test_snippet, conversation,
153                   /*model_executor=*/nullptr,
154                   /*model_spec=*/nullptr,
155                   /*interpreter=*/nullptr,
156                   /*actions_entity_data_schema=*/nullptr,
157                   /*annotations_entity_data_schema=*/nullptr)
158                   ->SuggestActions(&actions));
159   EXPECT_THAT(actions, testing::ElementsAreArray(
160                            {IsAction("text_reply", "i am at home")}));
161   EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
162 }
163 
164 TEST(LuaActions, EntityData) {
165   std::string test_schema = TestEntityDataSchema();
166   Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
167   const std::string test_snippet = R"(
168     return {{
169       type = "test",
170       entity = {
171         greeting = "hello",
172         location = "there",
173         person = "Kenobi",
174       },
175     }};
176   )";
177   std::vector<ActionSuggestion> actions;
178   EXPECT_TRUE(LuaActionsSuggestions::CreateLuaActionsSuggestions(
179                   test_snippet, conversation,
180                   /*model_executor=*/nullptr,
181                   /*model_spec=*/nullptr,
182                   /*interpreter=*/nullptr,
183                   /*actions_entity_data_schema=*/
184                   flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
185                   /*annotations_entity_data_schema=*/nullptr)
186                   ->SuggestActions(&actions));
187   EXPECT_THAT(actions, testing::SizeIs(1));
188   EXPECT_EQ("test", actions.front().type);
189   const flatbuffers::Table* entity =
190       flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
191           actions.front().serialized_entity_data.data()));
192   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
193             "hello");
194   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
195             "there");
196   EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
197             "Kenobi");
198 }
199 
200 }  // namespace
201 }  // namespace libtextclassifier3
202